diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java index c155da7d681..c7652d4eb43 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,13 @@ package org.springframework.messaging.simp.broker; +import java.util.Collection; +import java.util.HashSet; import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.logging.Log; @@ -33,8 +38,14 @@ import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; /** - * Submit messages to an {@link ExecutorSubscribableChannel}, one at a time. - * The channel must have been configured with {@link #configureOutboundChannel}. + * {@code MessageChannel} decorator that ensures messages from the same session + * are sent and processed in the same order. This would not normally be the case + * with an {@code Executor} backed {@code MessageChannel} since the executor + * is free to submit tasks in any order. + * + *

To provide ordering, inbound messages are placed in a queue and sent one + * one at a time per session. Once a message is processed, a callback is used to + * notify that the next message from the same session can be sent through. * * @author Rossen Stoyanchev * @since 5.1 @@ -48,9 +59,7 @@ class OrderedMessageSender implements MessageChannel { private final Log logger; - private final Queue> messages = new ConcurrentLinkedQueue<>(); - - private final AtomicBoolean sendInProgress = new AtomicBoolean(false); + private final Control control = new Control(); public OrderedMessageSender(MessageChannel channel, Log logger) { @@ -66,30 +75,40 @@ class OrderedMessageSender implements MessageChannel { @Override public boolean send(Message message, long timeout) { - this.messages.add(message); + this.control.addMessage(message); trySend(); return true; } private void trySend() { - // Take sendInProgress flag only if queue is not empty - if (this.messages.isEmpty()) { - return; - } - - if (this.sendInProgress.compareAndSet(false, true)) { - sendNextMessage(); + if (this.control.acquireSendLock()) { + sendMessages(); } } - private void sendNextMessage() { - for (;;) { - Message message = this.messages.poll(); - if (message != null) { + private void sendMessages() { + for ( ; ; ) { + Set skipSet = new HashSet<>(); + for (Message message : this.control.getMessagesToSend()) { + String sessionId = SimpMessageHeaderAccessor.getSessionId(message.getHeaders()); + Assert.notNull(sessionId, () -> "No session id in " + message.getHeaders()); + if (skipSet.contains(sessionId)) { + continue; + } + if (!this.control.acquireSessionLock(sessionId)) { + skipSet.add(sessionId); + continue; + } + this.control.removeMessage(message); try { - addCompletionCallback(message); + getMutableAccessor(message).setHeader(COMPLETION_TASK_HEADER, (Runnable) () -> { + this.control.releaseSessionLock(sessionId); + if (this.control.hasRemainingWork()) { + trySend(); + } + }); if (this.channel.send(message)) { - return; + continue; } } catch (Throwable ex) { @@ -97,20 +116,24 @@ class OrderedMessageSender implements MessageChannel { logger.error("Failed to send " + message, ex); } } + // We didn't send + this.control.releaseSessionLock(sessionId); } - else { - // We ran out of messages.. - this.sendInProgress.set(false); - trySend(); - break; + + if (this.control.shouldYield()) { + this.control.releaseSendLock(); + if (!this.control.shouldYield()) { + trySend(); + } + return; } } } - private void addCompletionCallback(Message msg) { - SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(msg, SimpMessageHeaderAccessor.class); + private SimpMessageHeaderAccessor getMutableAccessor(Message message) { + SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class); Assert.isTrue(accessor != null && accessor.isMutable(), "Expected mutable SimpMessageHeaderAccessor"); - accessor.setHeader(COMPLETION_TASK_HEADER, (Runnable) this::sendNextMessage); + return accessor; } @@ -126,13 +149,13 @@ class OrderedMessageSender implements MessageChannel { Assert.isInstanceOf(ExecutorSubscribableChannel.class, channel, "An ExecutorSubscribableChannel is required for `preservePublishOrder`"); ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel; - if (execChannel.getInterceptors().stream().noneMatch(i -> i instanceof CallbackInterceptor)) { - execChannel.addInterceptor(0, new CallbackInterceptor()); + if (execChannel.getInterceptors().stream().noneMatch(i -> i instanceof CompletionTaskInterceptor)) { + execChannel.addInterceptor(0, new CompletionTaskInterceptor()); } } else if (channel instanceof ExecutorSubscribableChannel) { ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel; - execChannel.getInterceptors().stream().filter(i -> i instanceof CallbackInterceptor) + execChannel.getInterceptors().stream().filter(i -> i instanceof CompletionTaskInterceptor) .findFirst() .map(execChannel::removeInterceptor); @@ -140,13 +163,71 @@ class OrderedMessageSender implements MessageChannel { } - private static class CallbackInterceptor implements ExecutorChannelInterceptor { + /** + * Provides locks required for ordered message sending and execution within + * a session as well as storage for messages waiting to be sent. + */ + private static class Control { + + private final Queue> messages = new ConcurrentLinkedQueue<>(); + + private final ConcurrentMap sessionsInProgress = new ConcurrentHashMap<>(); + + private final AtomicBoolean workInProgress = new AtomicBoolean(false); + + + public void addMessage(Message message) { + this.messages.add(message); + } + + public void removeMessage(Message message) { + if (!this.messages.remove(message)) { + throw new IllegalStateException( + "Message " + message.getHeaders() + " was expected in the queue."); + } + } + + public Collection> getMessagesToSend() { + return this.messages; + } + + public boolean acquireSendLock() { + return this.workInProgress.compareAndSet(false, true); + } + + public void releaseSendLock() { + this.workInProgress.set(false); + } + + public boolean acquireSessionLock(String sessionId) { + if (this.sessionsInProgress.put(sessionId, Boolean.TRUE) != null) { + return false; + } + return true; + } + + public void releaseSessionLock(String sessionId) { + this.sessionsInProgress.remove(sessionId); + } + + public boolean hasRemainingWork() { + return !this.messages.isEmpty(); + } + + public boolean shouldYield() { + // No remaining work, or others can pick it up + return (!hasRemainingWork() || this.sessionsInProgress.size() > 0); + } + } + + + private static class CompletionTaskInterceptor implements ExecutorChannelInterceptor { @Override public void afterMessageHandled( - Message msg, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) { + Message message, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) { - Runnable task = (Runnable) msg.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER); + Runnable task = (Runnable) message.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER); if (task != null) { task.run(); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java index c1fffb0a1a3..cf7914c1935 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,10 @@ package org.springframework.messaging.simp.broker; +import java.time.Duration; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -26,7 +30,12 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.ExecutorSubscribableChannel; @@ -43,6 +52,8 @@ public class OrderedMessageSenderTests { private static final Log logger = LogFactory.getLog(OrderedMessageSenderTests.class); + private static final Random random = new Random(); + private OrderedMessageSender sender; @@ -74,46 +85,97 @@ public class OrderedMessageSenderTests { @Test public void test() throws InterruptedException { - int start = 1; - int end = 1000; + int sessionCount = 25; + int messagesPerSessionCount = 500; + + TestMessageHandler handler = new TestMessageHandler(sessionCount * messagesPerSessionCount); + this.channel.subscribe(handler); + + Publisher>> messageFluxes = + Flux.range(1, sessionCount).map(sessionId -> + Flux.range(1, messagesPerSessionCount) + .map(sequence -> createMessage(sessionId, sequence)) + .delayElements(Duration.ofMillis(Math.abs(random.nextLong()) % 5))); + + Flux.merge(messageFluxes) + .doOnNext(message -> this.sender.send(message)) + .blockLast(); + + handler.await(20, TimeUnit.SECONDS); + + assertThat(handler.getDescription()).isEqualTo("Total processed: " + sessionCount * messagesPerSessionCount); + assertThat(handler.getSequenceBySession()).hasSize(sessionCount); + handler.getSequenceBySession().forEach((key, value) -> + assertThat(value.get()).as(key).isEqualTo(messagesPerSessionCount)); + } + + private static Message createMessage(Integer sessionId, Integer sequence) { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + accessor.setSessionId("session" + sessionId); + accessor.setHeader("seq", sequence); + accessor.setLeaveMutable(true); + return MessageBuilder.createMessage("payload", accessor.getMessageHeaders()); + } + + + private static class TestMessageHandler implements MessageHandler { + + private final int totalExpected; + + private final Map sequenceBySession = new ConcurrentHashMap<>(); + + private final AtomicReference description = new AtomicReference<>(); + + private final AtomicInteger totalReceived = new AtomicInteger(); - AtomicInteger index = new AtomicInteger(start); - AtomicReference result = new AtomicReference<>(); - CountDownLatch latch = new CountDownLatch(1); + private final CountDownLatch latch = new CountDownLatch(1); - this.channel.subscribe(message -> { - int expected = index.getAndIncrement(); - Integer actual = (Integer) message.getHeaders().getOrDefault("seq", -1); - if (actual != expected) { - result.set("Expected: " + expected + ", but was: " + actual); + TestMessageHandler(int totalExpected) { + this.totalExpected = totalExpected; + } + + public void await(long timeout, TimeUnit timeUnit) throws InterruptedException { + latch.await(timeout, timeUnit); + } + + public Map getSequenceBySession() { + return sequenceBySession; + } + + public String getDescription() { + return description.get(); + } + + @Override + public void handleMessage(Message message) throws MessagingException { + String id = SimpMessageHeaderAccessor.getSessionId(message.getHeaders()); + Integer seq = (Integer) message.getHeaders().getOrDefault("seq", -1); + + AtomicInteger prev = sequenceBySession.computeIfAbsent(id, i -> new AtomicInteger(0)); + if (!prev.compareAndSet(seq - 1, seq)) { + description.set("Out of order, session=" + id + ", prev=" + prev + ", next=" + seq); latch.countDown(); return; } - if (actual == 100 || actual == 200) { + + if (seq == 100) { try { - Thread.sleep(200); + // Processing delay to cause other session messages to queue up + Thread.sleep(50); } catch (InterruptedException ex) { - result.set(ex.toString()); + description.set(ex.toString()); latch.countDown(); + return; } } - if (actual == end) { - result.set("Done"); + + int total = totalReceived.incrementAndGet(); + description.set("Total processed: " + total); + if (total == totalExpected) { latch.countDown(); } - }); - - for (int i = start; i <= end; i++) { - SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); - accessor.setHeader("seq", i); - accessor.setLeaveMutable(true); - this.sender.send(MessageBuilder.createMessage("payload", accessor.getMessageHeaders())); } - - latch.await(10, TimeUnit.SECONDS); - assertThat(result.get()).isEqualTo("Done"); } - }