From bb941b61808afb4ed684d1769f381481fc2c646d Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 28 Aug 2020 20:40:15 +0100 Subject: [PATCH] OrderedMessageChannelDecorator doesn't preclude send limits Closes gh-25581 --- .../broker/AbstractBrokerMessageHandler.java | 4 +- ...va => OrderedMessageChannelDecorator.java} | 65 +++-- ... OrderedMessageChannelDecoratorTests.java} | 20 +- .../ConcurrentWebSocketSessionDecorator.java | 19 ++ .../messaging/StompSubProtocolHandler.java | 9 + .../handler/BlockingWebSocketSession.java | 61 +++++ ...currentWebSocketSessionDecoratorTests.java | 66 ++--- ...OrderedMessageSendingIntegrationTests.java | 255 ++++++++++++++++++ 8 files changed, 416 insertions(+), 83 deletions(-) rename spring-messaging/src/main/java/org/springframework/messaging/simp/broker/{OrderedMessageSender.java => OrderedMessageChannelDecorator.java} (64%) rename spring-messaging/src/test/java/org/springframework/messaging/simp/broker/{OrderedMessageSenderTests.java => OrderedMessageChannelDecoratorTests.java} (84%) create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/handler/BlockingWebSocketSession.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/messaging/OrderedMessageSendingIntegrationTests.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java index 2b7b5e07d0e..2f82c11e4aa 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java @@ -142,7 +142,7 @@ public abstract class AbstractBrokerMessageHandler * @since 5.1 */ public void setPreservePublishOrder(boolean preservePublishOrder) { - OrderedMessageSender.configureOutboundChannel(this.clientOutboundChannel, preservePublishOrder); + OrderedMessageChannelDecorator.configureInterceptor(this.clientOutboundChannel, preservePublishOrder); this.preservePublishOrder = preservePublishOrder; } @@ -298,7 +298,7 @@ public abstract class AbstractBrokerMessageHandler */ protected MessageChannel getClientOutboundChannelForSession(String sessionId) { return this.preservePublishOrder ? - new OrderedMessageSender(getClientOutboundChannel(), logger) : getClientOutboundChannel(); + new OrderedMessageChannelDecorator(getClientOutboundChannel(), logger) : getClientOutboundChannel(); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java similarity index 64% rename from spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java rename to spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java index c155da7d681..9dbe96e1dea 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,15 +33,17 @@ import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; /** - * Submit messages to an {@link ExecutorSubscribableChannel}, one at a time. - * The channel must have been configured with {@link #configureOutboundChannel}. + * Decorator for an {@link ExecutorSubscribableChannel} that ensures messages + * are processed in the order they were published to the channel. Messages are + * sent one at a time with the next one released when the prevoius has been + * processed. This decorator is intended to be applied per session. * * @author Rossen Stoyanchev * @since 5.1 */ -class OrderedMessageSender implements MessageChannel { +public class OrderedMessageChannelDecorator implements MessageChannel { - static final String COMPLETION_TASK_HEADER = "simpSendCompletionTask"; + private static final String NEXT_MESSAGE_TASK_HEADER = "simpNextMessageTask"; private final MessageChannel channel; @@ -53,7 +55,7 @@ class OrderedMessageSender implements MessageChannel { private final AtomicBoolean sendInProgress = new AtomicBoolean(false); - public OrderedMessageSender(MessageChannel channel, Log logger) { + public OrderedMessageChannelDecorator(MessageChannel channel, Log logger) { this.channel = channel; this.logger = logger; } @@ -84,10 +86,14 @@ class OrderedMessageSender implements MessageChannel { private void sendNextMessage() { for (;;) { - Message message = this.messages.poll(); + Message message = this.messages.peek(); if (message != null) { try { - addCompletionCallback(message); + addNextMessageTaskHeader(message, () -> { + if (removeMessage(message)) { + sendNextMessage(); + } + }); if (this.channel.send(message)) { return; } @@ -97,9 +103,9 @@ class OrderedMessageSender implements MessageChannel { logger.error("Failed to send " + message, ex); } } + removeMessage(message); } else { - // We ran out of messages.. this.sendInProgress.set(false); trySend(); break; @@ -107,22 +113,40 @@ class OrderedMessageSender implements MessageChannel { } } - private void addCompletionCallback(Message msg) { - SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(msg, SimpMessageHeaderAccessor.class); + private boolean removeMessage(Message message) { + Message next = this.messages.peek(); + if (next == message) { + this.messages.remove(); + return true; + } + else { + return false; + } + } + + private static void addNextMessageTaskHeader(Message message, Runnable task) { + SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class); Assert.isTrue(accessor != null && accessor.isMutable(), "Expected mutable SimpMessageHeaderAccessor"); - accessor.setHeader(COMPLETION_TASK_HEADER, (Runnable) this::sendNextMessage); + accessor.setHeader(NEXT_MESSAGE_TASK_HEADER, task); } + /** + * Obtain the task to release the next message, if found. + */ + @Nullable + public static Runnable getNextMessageTask(Message message) { + return (Runnable) message.getHeaders().get(OrderedMessageChannelDecorator.NEXT_MESSAGE_TASK_HEADER); + } /** * Install or remove an {@link ExecutorChannelInterceptor} that invokes a - * completion task once the message is handled. + * completion task, if found in the headers of the message. * @param channel the channel to configure - * @param preservePublishOrder whether preserve order is on or off based on - * which an interceptor is either added or removed. + * @param preserveOrder whether preserve the order or publication; when + * "true" an interceptor is inserted, when "false" it removed. */ - static void configureOutboundChannel(MessageChannel channel, boolean preservePublishOrder) { - if (preservePublishOrder) { + public static void configureInterceptor(MessageChannel channel, boolean preserveOrder) { + if (preserveOrder) { Assert.isInstanceOf(ExecutorSubscribableChannel.class, channel, "An ExecutorSubscribableChannel is required for `preservePublishOrder`"); ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel; @@ -133,8 +157,7 @@ class OrderedMessageSender implements MessageChannel { else if (channel instanceof ExecutorSubscribableChannel) { ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel; execChannel.getInterceptors().stream().filter(i -> i instanceof CallbackInterceptor) - .findFirst() - .map(execChannel::removeInterceptor); + .findFirst().map(execChannel::removeInterceptor); } } @@ -144,9 +167,9 @@ class OrderedMessageSender implements MessageChannel { @Override public void afterMessageHandled( - Message msg, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) { + Message message, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) { - Runnable task = (Runnable) msg.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER); + Runnable task = getNextMessageTask(message); if (task != null) { task.run(); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecoratorTests.java similarity index 84% rename from spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java rename to spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecoratorTests.java index c1fffb0a1a3..4d8b8b636b3 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecoratorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,15 +36,16 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import static org.assertj.core.api.Assertions.assertThat; /** - * Unit tests for {@link OrderedMessageSender}. + * Unit tests for {@link OrderedMessageChannelDecorator}. * @author Rossen Stoyanchev + * @see org.springframework.web.socket.messaging.OrderedMessageSendingIntegrationTests */ -public class OrderedMessageSenderTests { +public class OrderedMessageChannelDecoratorTests { - private static final Log logger = LogFactory.getLog(OrderedMessageSenderTests.class); + private static final Log logger = LogFactory.getLog(OrderedMessageChannelDecoratorTests.class); - private OrderedMessageSender sender; + private OrderedMessageChannelDecorator sender; ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(this.executor); @@ -59,9 +60,9 @@ public class OrderedMessageSenderTests { this.executor.afterPropertiesSet(); this.channel = new ExecutorSubscribableChannel(this.executor); - OrderedMessageSender.configureOutboundChannel(this.channel, true); + OrderedMessageChannelDecorator.configureInterceptor(this.channel, true); - this.sender = new OrderedMessageSender(this.channel, logger); + this.sender = new OrderedMessageChannelDecorator(this.channel, logger); } @@ -89,9 +90,10 @@ public class OrderedMessageSenderTests { latch.countDown(); return; } - if (actual == 100 || actual == 200) { + // Force messages to queue up periodically + if (actual % 101 == 0) { try { - Thread.sleep(200); + Thread.sleep(50); } catch (InterruptedException ex) { result.set(ex.toString()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java index 4c64bcfd608..46f3149fe85 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java @@ -22,10 +22,12 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.lang.Nullable; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; @@ -54,6 +56,10 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat private final OverflowStrategy overflowStrategy; + @Nullable + private Consumer> preSendCallback; + + private final Queue> buffer = new LinkedBlockingQueue<>(); private final AtomicInteger bufferSize = new AtomicInteger(); @@ -130,6 +136,15 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat return (start > 0 ? (System.currentTimeMillis() - start) : 0); } + /** + * Set a callback invoked after a message is added to the send buffer. + * @param callback the callback to invoke + * @since 5.3 + */ + public void setMessageCallback(Consumer> callback) { + this.preSendCallback = callback; + } + @Override public void sendMessage(WebSocketMessage message) throws IOException { @@ -140,6 +155,10 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat this.buffer.add(message); this.bufferSize.addAndGet(message.getPayloadLength()); + if (this.preSendCallback != null) { + this.preSendCallback.accept(message); + } + do { if (!tryFlushMessageBuffer()) { if (logger.isTraceEnabled()) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index b47ec0fca22..c48ebe9c426 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -39,6 +39,7 @@ import org.springframework.messaging.simp.SimpAttributes; import org.springframework.messaging.simp.SimpAttributesContextHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator; import org.springframework.messaging.simp.stomp.BufferingStompDecoder; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompDecoder; @@ -57,6 +58,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; import org.springframework.web.socket.handler.SessionLimitExceededException; import org.springframework.web.socket.handler.WebSocketSessionDecorator; import org.springframework.web.socket.sockjs.transport.SockJsSession; @@ -461,6 +463,13 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE payload = errorMessage.getPayload(); } } + + Runnable task = OrderedMessageChannelDecorator.getNextMessageTask(message); + if (task != null) { + Assert.isInstanceOf(ConcurrentWebSocketSessionDecorator.class, session); + ((ConcurrentWebSocketSessionDecorator) session).setMessageCallback(m -> task.run()); + } + sendToClient(session, accessor, payload); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/BlockingWebSocketSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/BlockingWebSocketSession.java new file mode 100644 index 00000000000..6e592719bec --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/BlockingWebSocketSession.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.socket.handler; + +import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import org.springframework.web.socket.WebSocketMessage; + +/** + * Blocks indefinitely on sending a message but provides a latch to notify when + * the message has been "sent" (i.e. session is blocked). + * + * @author Rossen Stoyanchev + */ +public class BlockingWebSocketSession extends TestWebSocketSession { + + private final AtomicReference sendLatch = new AtomicReference<>(); + + private final AtomicReference releaseLatch = new AtomicReference<>(); + + + public CountDownLatch initSendLatch() { + this.sendLatch.set(new CountDownLatch(1)); + return this.sendLatch.get(); + } + + @Override + public void sendMessage(WebSocketMessage message) throws IOException { + super.sendMessage(message); + if (this.sendLatch.get() != null) { + this.sendLatch.get().countDown(); + } + block(); + } + + private void block() { + try { + this.releaseLatch.set(new CountDownLatch(1)); + this.releaseLatch.get().await(); + } + catch (InterruptedException ex) { + ex.printStackTrace(); + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java index 937e1d992a0..3ba01e858be 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,13 +20,11 @@ import java.io.IOException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.Test; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; -import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator.OverflowStrategy; @@ -63,7 +61,7 @@ public class ConcurrentWebSocketSessionDecoratorTests { @Test public void sendAfterBlockedSend() throws IOException, InterruptedException { - BlockingSession session = new BlockingSession(); + BlockingWebSocketSession session = new BlockingWebSocketSession(); session.setOpen(true); ConcurrentWebSocketSessionDecorator decorator = @@ -85,9 +83,9 @@ public class ConcurrentWebSocketSessionDecoratorTests { } @Test - public void sendTimeLimitExceeded() throws IOException, InterruptedException { + public void sendTimeLimitExceeded() throws InterruptedException { - BlockingSession session = new BlockingSession(); + BlockingWebSocketSession session = new BlockingWebSocketSession(); session.setId("123"); session.setOpen(true); @@ -109,7 +107,7 @@ public class ConcurrentWebSocketSessionDecoratorTests { @Test public void sendBufferSizeExceeded() throws IOException, InterruptedException { - BlockingSession session = new BlockingSession(); + BlockingWebSocketSession session = new BlockingWebSocketSession(); session.setId("123"); session.setOpen(true); @@ -134,7 +132,7 @@ public class ConcurrentWebSocketSessionDecoratorTests { @Test // SPR-17140 public void overflowStrategyDrop() throws IOException, InterruptedException { - BlockingSession session = new BlockingSession(); + BlockingWebSocketSession session = new BlockingWebSocketSession(); session.setId("123"); session.setOpen(true); @@ -157,7 +155,7 @@ public class ConcurrentWebSocketSessionDecoratorTests { @Test public void closeStatusNormal() throws Exception { - BlockingSession session = new BlockingSession(); + BlockingWebSocketSession session = new BlockingWebSocketSession(); session.setOpen(true); WebSocketSession decorator = new ConcurrentWebSocketSessionDecorator(session, 10 * 1000, 1024); @@ -171,10 +169,10 @@ public class ConcurrentWebSocketSessionDecoratorTests { @Test public void closeStatusChangesToSessionNotReliable() throws Exception { - BlockingSession session = new BlockingSession(); + BlockingWebSocketSession session = new BlockingWebSocketSession(); session.setId("123"); session.setOpen(true); - CountDownLatch sentMessageLatch = session.getSentMessageLatch(); + CountDownLatch sentMessageLatch = session.initSendLatch(); int sendTimeLimit = 100; int bufferSizeLimit = 1024; @@ -182,7 +180,7 @@ public class ConcurrentWebSocketSessionDecoratorTests { ConcurrentWebSocketSessionDecorator decorator = new ConcurrentWebSocketSessionDecorator(session, sendTimeLimit, bufferSizeLimit); - Executors.newSingleThreadExecutor().submit((Runnable) () -> { + Executors.newSingleThreadExecutor().submit(() -> { TextMessage message = new TextMessage("slow message"); try { decorator.sendMessage(message); @@ -199,12 +197,13 @@ public class ConcurrentWebSocketSessionDecoratorTests { decorator.close(CloseStatus.PROTOCOL_ERROR); - assertThat(session.getCloseStatus()).as("CloseStatus should have changed to SESSION_NOT_RELIABLE").isEqualTo(CloseStatus.SESSION_NOT_RELIABLE); + assertThat(session.getCloseStatus()) + .as("CloseStatus should have changed to SESSION_NOT_RELIABLE") + .isEqualTo(CloseStatus.SESSION_NOT_RELIABLE); } private void sendBlockingMessage(ConcurrentWebSocketSessionDecorator session) throws InterruptedException { - BlockingSession delegate = (BlockingSession) session.getDelegate(); - CountDownLatch sentMessageLatch = delegate.getSentMessageLatch(); + CountDownLatch latch = ((BlockingWebSocketSession) session.getDelegate()).initSendLatch(); Executors.newSingleThreadExecutor().submit(() -> { TextMessage message = new TextMessage("slow message"); try { @@ -214,42 +213,7 @@ public class ConcurrentWebSocketSessionDecoratorTests { e.printStackTrace(); } }); - assertThat(sentMessageLatch.await(5, TimeUnit.SECONDS)).isTrue(); - } - - - - private static class BlockingSession extends TestWebSocketSession { - - private final AtomicReference nextMessageLatch = new AtomicReference<>(); - - private final AtomicReference releaseLatch = new AtomicReference<>(); - - - public CountDownLatch getSentMessageLatch() { - this.nextMessageLatch.set(new CountDownLatch(1)); - return this.nextMessageLatch.get(); - } - - @Override - public void sendMessage(WebSocketMessage message) throws IOException { - super.sendMessage(message); - if (this.nextMessageLatch != null) { - this.nextMessageLatch.get().countDown(); - } - block(); - } - - private void block() { - try { - this.releaseLatch.set(new CountDownLatch(1)); - this.releaseLatch.get().await(); - } - catch (InterruptedException e) { - e.printStackTrace(); - } - } - + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/OrderedMessageSendingIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/OrderedMessageSendingIntegrationTests.java new file mode 100644 index 00000000000..5f8cf79331f --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/OrderedMessageSendingIntegrationTests.java @@ -0,0 +1,255 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.socket.messaging; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.lang.Nullable; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; +import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompEncoder; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.ExecutorSubscribableChannel; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.handler.BlockingWebSocketSession; +import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests to publish messages to an Executor backed channel wrapped with + * {@link OrderedMessageChannelDecorator} and handled by + * {@link StompSubProtocolHandler} delegating to a + * {@link ConcurrentWebSocketSessionDecorator} wrapped session. + * + *

The tests verify that: + *

    + *
  • messages are executed in the same order as they are published. + *
  • send buffer size and send time limits at the + * {@link ConcurrentWebSocketSessionDecorator} level are enforced. + *
+ * + *

The key is for {@link OrderedMessageChannelDecorator} to release the next + * message when after the current one is queued for sending, and not after it is + * sent, which may block and cause messages to accumulate in the + * {@link OrderedMessageChannelDecorator} instead of in + * {@link ConcurrentWebSocketSessionDecorator} where send limits are enforced. + * + * @author Rossen Stoyanchev + */ +public class OrderedMessageSendingIntegrationTests { + + private static final Log logger = LogFactory.getLog(OrderedMessageSendingIntegrationTests.class); + + private static final int MESSAGE_SIZE = new StompEncoder().encode(createMessage(0)).length; + + + private BlockingWebSocketSession blockingSession; + + private ExecutorSubscribableChannel subscribableChannel; + + private OrderedMessageChannelDecorator orderedMessageChannel; + + private ThreadPoolTaskExecutor executor; + + + + @BeforeEach + public void setup() { + this.blockingSession = new BlockingWebSocketSession(); + this.blockingSession.setId("1"); + this.blockingSession.setOpen(true); + + this.executor = new ThreadPoolTaskExecutor(); + this.executor.setCorePoolSize(Runtime.getRuntime().availableProcessors() * 2); + this.executor.setAllowCoreThreadTimeOut(true); + this.executor.afterPropertiesSet(); + + this.subscribableChannel = new ExecutorSubscribableChannel(this.executor); + OrderedMessageChannelDecorator.configureInterceptor(this.subscribableChannel, true); + + this.orderedMessageChannel = new OrderedMessageChannelDecorator(this.subscribableChannel, logger); + } + + @AfterEach + public void tearDown() { + this.executor.shutdown(); + } + + @Test + void sendAfterBlockedSend() throws InterruptedException { + + int messageCount = 1000; + + ConcurrentWebSocketSessionDecorator concurrentSessionDecorator = + new ConcurrentWebSocketSessionDecorator( + this.blockingSession, 60 * 1000, messageCount * MESSAGE_SIZE); + + TestMessageHandler handler = new TestMessageHandler(concurrentSessionDecorator); + subscribableChannel.subscribe(handler); + + List> expectedMessages = new ArrayList<>(messageCount); + + // Send one to block + Message message = createMessage(0); + expectedMessages.add(message); + this.orderedMessageChannel.send(message); + + CountDownLatch latch = new CountDownLatch(messageCount); + handler.setMessageLatch(latch); + + for (int i = 1; i <= messageCount; i++) { + message = createMessage(i); + expectedMessages.add(message); + this.orderedMessageChannel.send(message); + } + + latch.await(5, TimeUnit.SECONDS); + + assertThat(concurrentSessionDecorator.getTimeSinceSendStarted() > 0).isTrue(); + assertThat(concurrentSessionDecorator.getBufferSize()).isEqualTo((messageCount * MESSAGE_SIZE)); + assertThat(handler.getSavedMessages()).containsExactlyElementsOf(expectedMessages); + assertThat(blockingSession.isOpen()).isTrue(); + } + + @Test + void exceedTimeLimit() throws InterruptedException { + + ConcurrentWebSocketSessionDecorator concurrentSessionDecorator = + new ConcurrentWebSocketSessionDecorator(this.blockingSession, 100, 1024); + + TestMessageHandler messageHandler = new TestMessageHandler(concurrentSessionDecorator); + subscribableChannel.subscribe(messageHandler); + + // Send one to block + this.orderedMessageChannel.send(createMessage(0)); + + // Exceed send time.. + Thread.sleep(200); + + CountDownLatch messageLatch = new CountDownLatch(1); + messageHandler.setMessageLatch(messageLatch); + + // Send one more + this.orderedMessageChannel.send(createMessage(1)); + + messageLatch.await(5, TimeUnit.SECONDS); + + assertThat(messageHandler.getSavedException()).hasMessageMatching( + "Send time [\\d]+ \\(ms\\) for session '1' exceeded the allowed limit 100"); + } + + @Test + void exceedBufferSizeLimit() throws InterruptedException { + + ConcurrentWebSocketSessionDecorator concurrentSessionDecorator = + new ConcurrentWebSocketSessionDecorator(this.blockingSession, 60 * 1000, 2 * MESSAGE_SIZE); + + TestMessageHandler messageHandler = new TestMessageHandler(concurrentSessionDecorator); + subscribableChannel.subscribe(messageHandler); + + // Send one to block + this.orderedMessageChannel.send(createMessage(0)); + + int messageCount = 3; + CountDownLatch messageLatch = new CountDownLatch(messageCount); + messageHandler.setMessageLatch(messageLatch); + + for (int i = 1; i <= messageCount; i++) { + this.orderedMessageChannel.send(createMessage(i)); + } + + messageLatch.await(5, TimeUnit.SECONDS); + + assertThat(messageHandler.getSavedException()).hasMessage( + "Buffer size " + 3 * MESSAGE_SIZE + " bytes for session '1' exceeds the allowed limit " + 2 * MESSAGE_SIZE); + } + + private static Message createMessage(int index) { + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.MESSAGE); + accessor.setHeader("index", index); + accessor.setSubscriptionId("1"); + accessor.setLeaveMutable(true); + byte[] bytes = "payload".getBytes(StandardCharsets.UTF_8); + return MessageBuilder.createMessage(bytes, accessor.getMessageHeaders()); + + } + + + private static class TestMessageHandler implements MessageHandler { + + private final StompSubProtocolHandler subProtocolHandler = new StompSubProtocolHandler(); + + private final WebSocketSession session; + + @Nullable + private CountDownLatch messageLatch; + + private Queue> messages = new LinkedBlockingQueue<>(); + + private AtomicReference exception = new AtomicReference<>(); + + + public TestMessageHandler(WebSocketSession session) { + this.session = session; + } + + public void setMessageLatch(CountDownLatch latch) { + this.messageLatch = latch; + } + + public Collection> getSavedMessages() { + return this.messages; + } + + public Exception getSavedException() { + return this.exception.get(); + } + + @Override + public void handleMessage(Message message) throws MessagingException { + this.messages.add(message); + try { + this.subProtocolHandler.handleMessageToClient(this.session, message); + } + catch (Exception ex) { + this.exception.set(ex); + } + if (this.messageLatch != null) { + this.messageLatch.countDown(); + } + } + } +}