diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java
index c155da7d681..c7652d4eb43 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,8 +16,13 @@
package org.springframework.messaging.simp.broker;
+import java.util.Collection;
+import java.util.HashSet;
import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.logging.Log;
@@ -33,8 +38,14 @@ import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
/**
- * Submit messages to an {@link ExecutorSubscribableChannel}, one at a time.
- * The channel must have been configured with {@link #configureOutboundChannel}.
+ * {@code MessageChannel} decorator that ensures messages from the same session
+ * are sent and processed in the same order. This would not normally be the case
+ * with an {@code Executor} backed {@code MessageChannel} since the executor
+ * is free to submit tasks in any order.
+ *
+ *
To provide ordering, inbound messages are placed in a queue and sent one
+ * one at a time per session. Once a message is processed, a callback is used to
+ * notify that the next message from the same session can be sent through.
*
* @author Rossen Stoyanchev
* @since 5.1
@@ -48,9 +59,7 @@ class OrderedMessageSender implements MessageChannel {
private final Log logger;
- private final Queue> messages = new ConcurrentLinkedQueue<>();
-
- private final AtomicBoolean sendInProgress = new AtomicBoolean(false);
+ private final Control control = new Control();
public OrderedMessageSender(MessageChannel channel, Log logger) {
@@ -66,30 +75,40 @@ class OrderedMessageSender implements MessageChannel {
@Override
public boolean send(Message> message, long timeout) {
- this.messages.add(message);
+ this.control.addMessage(message);
trySend();
return true;
}
private void trySend() {
- // Take sendInProgress flag only if queue is not empty
- if (this.messages.isEmpty()) {
- return;
- }
-
- if (this.sendInProgress.compareAndSet(false, true)) {
- sendNextMessage();
+ if (this.control.acquireSendLock()) {
+ sendMessages();
}
}
- private void sendNextMessage() {
- for (;;) {
- Message> message = this.messages.poll();
- if (message != null) {
+ private void sendMessages() {
+ for ( ; ; ) {
+ Set skipSet = new HashSet<>();
+ for (Message> message : this.control.getMessagesToSend()) {
+ String sessionId = SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
+ Assert.notNull(sessionId, () -> "No session id in " + message.getHeaders());
+ if (skipSet.contains(sessionId)) {
+ continue;
+ }
+ if (!this.control.acquireSessionLock(sessionId)) {
+ skipSet.add(sessionId);
+ continue;
+ }
+ this.control.removeMessage(message);
try {
- addCompletionCallback(message);
+ getMutableAccessor(message).setHeader(COMPLETION_TASK_HEADER, (Runnable) () -> {
+ this.control.releaseSessionLock(sessionId);
+ if (this.control.hasRemainingWork()) {
+ trySend();
+ }
+ });
if (this.channel.send(message)) {
- return;
+ continue;
}
}
catch (Throwable ex) {
@@ -97,20 +116,24 @@ class OrderedMessageSender implements MessageChannel {
logger.error("Failed to send " + message, ex);
}
}
+ // We didn't send
+ this.control.releaseSessionLock(sessionId);
}
- else {
- // We ran out of messages..
- this.sendInProgress.set(false);
- trySend();
- break;
+
+ if (this.control.shouldYield()) {
+ this.control.releaseSendLock();
+ if (!this.control.shouldYield()) {
+ trySend();
+ }
+ return;
}
}
}
- private void addCompletionCallback(Message> msg) {
- SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(msg, SimpMessageHeaderAccessor.class);
+ private SimpMessageHeaderAccessor getMutableAccessor(Message> message) {
+ SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class);
Assert.isTrue(accessor != null && accessor.isMutable(), "Expected mutable SimpMessageHeaderAccessor");
- accessor.setHeader(COMPLETION_TASK_HEADER, (Runnable) this::sendNextMessage);
+ return accessor;
}
@@ -126,13 +149,13 @@ class OrderedMessageSender implements MessageChannel {
Assert.isInstanceOf(ExecutorSubscribableChannel.class, channel,
"An ExecutorSubscribableChannel is required for `preservePublishOrder`");
ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel;
- if (execChannel.getInterceptors().stream().noneMatch(i -> i instanceof CallbackInterceptor)) {
- execChannel.addInterceptor(0, new CallbackInterceptor());
+ if (execChannel.getInterceptors().stream().noneMatch(i -> i instanceof CompletionTaskInterceptor)) {
+ execChannel.addInterceptor(0, new CompletionTaskInterceptor());
}
}
else if (channel instanceof ExecutorSubscribableChannel) {
ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel;
- execChannel.getInterceptors().stream().filter(i -> i instanceof CallbackInterceptor)
+ execChannel.getInterceptors().stream().filter(i -> i instanceof CompletionTaskInterceptor)
.findFirst()
.map(execChannel::removeInterceptor);
@@ -140,13 +163,71 @@ class OrderedMessageSender implements MessageChannel {
}
- private static class CallbackInterceptor implements ExecutorChannelInterceptor {
+ /**
+ * Provides locks required for ordered message sending and execution within
+ * a session as well as storage for messages waiting to be sent.
+ */
+ private static class Control {
+
+ private final Queue> messages = new ConcurrentLinkedQueue<>();
+
+ private final ConcurrentMap sessionsInProgress = new ConcurrentHashMap<>();
+
+ private final AtomicBoolean workInProgress = new AtomicBoolean(false);
+
+
+ public void addMessage(Message> message) {
+ this.messages.add(message);
+ }
+
+ public void removeMessage(Message> message) {
+ if (!this.messages.remove(message)) {
+ throw new IllegalStateException(
+ "Message " + message.getHeaders() + " was expected in the queue.");
+ }
+ }
+
+ public Collection> getMessagesToSend() {
+ return this.messages;
+ }
+
+ public boolean acquireSendLock() {
+ return this.workInProgress.compareAndSet(false, true);
+ }
+
+ public void releaseSendLock() {
+ this.workInProgress.set(false);
+ }
+
+ public boolean acquireSessionLock(String sessionId) {
+ if (this.sessionsInProgress.put(sessionId, Boolean.TRUE) != null) {
+ return false;
+ }
+ return true;
+ }
+
+ public void releaseSessionLock(String sessionId) {
+ this.sessionsInProgress.remove(sessionId);
+ }
+
+ public boolean hasRemainingWork() {
+ return !this.messages.isEmpty();
+ }
+
+ public boolean shouldYield() {
+ // No remaining work, or others can pick it up
+ return (!hasRemainingWork() || this.sessionsInProgress.size() > 0);
+ }
+ }
+
+
+ private static class CompletionTaskInterceptor implements ExecutorChannelInterceptor {
@Override
public void afterMessageHandled(
- Message> msg, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) {
+ Message> message, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) {
- Runnable task = (Runnable) msg.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER);
+ Runnable task = (Runnable) message.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER);
if (task != null) {
task.run();
}
diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java
index c1fffb0a1a3..cf7914c1935 100644
--- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java
+++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,10 @@
package org.springframework.messaging.simp.broker;
+import java.time.Duration;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
@@ -26,7 +30,12 @@ import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.reactivestreams.Publisher;
+import reactor.core.publisher.Flux;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.MessageHandler;
+import org.springframework.messaging.MessagingException;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
@@ -43,6 +52,8 @@ public class OrderedMessageSenderTests {
private static final Log logger = LogFactory.getLog(OrderedMessageSenderTests.class);
+ private static final Random random = new Random();
+
private OrderedMessageSender sender;
@@ -74,46 +85,97 @@ public class OrderedMessageSenderTests {
@Test
public void test() throws InterruptedException {
- int start = 1;
- int end = 1000;
+ int sessionCount = 25;
+ int messagesPerSessionCount = 500;
+
+ TestMessageHandler handler = new TestMessageHandler(sessionCount * messagesPerSessionCount);
+ this.channel.subscribe(handler);
+
+ Publisher>> messageFluxes =
+ Flux.range(1, sessionCount).map(sessionId ->
+ Flux.range(1, messagesPerSessionCount)
+ .map(sequence -> createMessage(sessionId, sequence))
+ .delayElements(Duration.ofMillis(Math.abs(random.nextLong()) % 5)));
+
+ Flux.merge(messageFluxes)
+ .doOnNext(message -> this.sender.send(message))
+ .blockLast();
+
+ handler.await(20, TimeUnit.SECONDS);
+
+ assertThat(handler.getDescription()).isEqualTo("Total processed: " + sessionCount * messagesPerSessionCount);
+ assertThat(handler.getSequenceBySession()).hasSize(sessionCount);
+ handler.getSequenceBySession().forEach((key, value) ->
+ assertThat(value.get()).as(key).isEqualTo(messagesPerSessionCount));
+ }
+
+ private static Message createMessage(Integer sessionId, Integer sequence) {
+ SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
+ accessor.setSessionId("session" + sessionId);
+ accessor.setHeader("seq", sequence);
+ accessor.setLeaveMutable(true);
+ return MessageBuilder.createMessage("payload", accessor.getMessageHeaders());
+ }
+
+
+ private static class TestMessageHandler implements MessageHandler {
+
+ private final int totalExpected;
+
+ private final Map sequenceBySession = new ConcurrentHashMap<>();
+
+ private final AtomicReference description = new AtomicReference<>();
+
+ private final AtomicInteger totalReceived = new AtomicInteger();
- AtomicInteger index = new AtomicInteger(start);
- AtomicReference