diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java index 3e1050c9205..0f54c67fd80 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java @@ -54,7 +54,8 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry { public ServletStompEndpointRegistry(WebSocketHandler webSocketHandler, - MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) { + MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler, + boolean handleConnect) { Assert.notNull(webSocketHandler); Assert.notNull(userQueueSuffixResolver); @@ -63,6 +64,7 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry { this.subProtocolWebSocketHandler = findSubProtocolWebSocketHandler(webSocketHandler); this.stompHandler = new StompProtocolHandler(); this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver); + this.stompHandler.setHandleConnect(handleConnect); this.sockJsScheduler = defaultSockJsTaskScheduler; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java index ef077487fab..ba5fb13baa2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java @@ -57,8 +57,10 @@ public abstract class WebSocketMessageBrokerConfigurationSupport { @Bean public HandlerMapping brokerWebSocketHandlerMapping() { - ServletStompEndpointRegistry registry = new ServletStompEndpointRegistry( - subProtocolWebSocketHandler(), userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler()); + boolean brokerRelayConfigured = getMessageBrokerConfigurer().getStompBrokerRelay() != null; + ServletStompEndpointRegistry registry = new ServletStompEndpointRegistry(subProtocolWebSocketHandler(), + userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler(), !brokerRelayConfigured); + registerStompEndpoints(registry); AbstractHandlerMapping hm = registry.getHandlerMapping(); hm.setOrder(1); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index d2783cc94b0..5d93031ba0a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -17,13 +17,9 @@ package org.springframework.messaging.simp.stomp; import java.net.InetSocketAddress; -import java.util.ArrayList; import java.util.Collection; -import java.util.List; import java.util.Map; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicReference; import org.springframework.messaging.Message; @@ -249,12 +245,8 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private final String sessionId; - private final BlockingQueue> messageQueue = new LinkedBlockingQueue>(50); - private volatile StompConnection stompConnection = new StompConnection(); - private final Object monitor = new Object(); - private RelaySession(String sessionId) { Assert.notNull(sessionId, "sessionId is required"); @@ -291,6 +283,12 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler protected void handleTcpConnection(TcpConnection, Message> tcpConn, final Message connectMessage) { this.stompConnection.setTcpConnection(tcpConn); + tcpConn.on().close(new Runnable() { + @Override + public void run() { + connectionClosed(); + } + }); tcpConn.in().consume(new Consumer>() { @Override public void accept(Message message) { @@ -307,12 +305,8 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); if (StompCommand.CONNECTED == headers.getCommand()) { - synchronized(this.monitor) { - this.stompConnection.setReady(); - publishBrokerAvailableEvent(); - flushMessages(); - } - return; + this.stompConnection.setReady(); + publishBrokerAvailableEvent(); } headers.setSessionId(this.sessionId); @@ -344,24 +338,11 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler public void forward(Message message) { if (!this.stompConnection.isReady()) { - synchronized(this.monitor) { - if (!this.stompConnection.isReady()) { - this.messageQueue.add(message); - if (logger.isTraceEnabled()) { - logger.trace("Not connected, message queued. Queue size=" + this.messageQueue.size()); - } - return; - } - } + logger.warn("Message sent to relay before it was CONNECTED. Discarding message: " + message); + return; } - if (this.messageQueue.isEmpty()) { - forwardInternal(message); - } - else { - this.messageQueue.add(message); - flushMessages(); - } + forwardInternal(message); } private boolean forwardInternal(final Message message) { @@ -381,6 +362,12 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler logger.trace("Forwarding to STOMP broker, message: " + message); } + StompCommand command = StompHeaderAccessor.wrap(message).getCommand(); + + if (command == StompCommand.DISCONNECT) { + this.stompConnection.setDisconnected(); + } + final Deferred> deferred = new DeferredPromiseSpec().get(); tcpConnection.send((Message)message, new Consumer() { @Override @@ -396,7 +383,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler handleTcpClientFailure("Timed out waiting for message to be forwarded to the broker", null); } else if (!success) { - if (StompHeaderAccessor.wrap(message).getCommand() != StompCommand.DISCONNECT) { + if (command != StompCommand.DISCONNECT) { handleTcpClientFailure("Failed to forward message to the broker", null); } } @@ -408,13 +395,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler return (success != null) ? success : false; } - private void flushMessages() { - List> messages = new ArrayList>(); - this.messageQueue.drainTo(messages); - for (Message message : messages) { - if (!forwardInternal(message)) { - return; - } + protected void connectionClosed() { + relaySessions.remove(this.sessionId); + if (this.stompConnection.isReady()) { + sendError("Lost connection to the broker"); } } } @@ -461,6 +445,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private class SystemRelaySession extends RelaySession { + private static final long HEARTBEAT_SEND_INTERVAL = 10000; + + private static final long HEARTBEAT_RECEIVE_INTERVAL = 10000; + public static final String ID = "stompRelaySystemSessionId"; @@ -473,7 +461,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler headers.setAcceptVersion("1.1,1.2"); headers.setLogin(systemLogin); headers.setPasscode(systemPasscode); - headers.setHeartbeat(0,0); + headers.setHeartbeat(HEARTBEAT_SEND_INTERVAL, HEARTBEAT_RECEIVE_INTERVAL); Message connectMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); super.connect(connectMessage); } @@ -488,6 +476,11 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler }); } + @Override + protected void connectionClosed() { + publishBrokerUnavailableEvent(); + } + @Override protected void sendMessageToClient(Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java index f4ead2a35c9..1c70bbb669a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java @@ -70,6 +70,7 @@ public class StompProtocolHandler implements SubProtocolHandler { private MutableUserQueueSuffixResolver queueSuffixResolver = new SimpleUserQueueSuffixResolver(); + private volatile boolean handleConnect = false; /** * Configure a resolver to use to maintain queue suffixes for user @@ -86,6 +87,29 @@ public class StompProtocolHandler implements SubProtocolHandler { return this.queueSuffixResolver; } + /** + * Configures the handling of CONNECT frames. When {@code true}, CONNECT + * frames will be handled by this handler, and a CONNECTED response will be + * sent. When {@code false}, CONNECT frames will be forwarded for + * handling by another component. + * + * @param handleConnect {@code true} if connect frames should be handled + * by this handler, {@code false} otherwise. + */ + public void setHandleConnect(boolean handleConnect) { + this.handleConnect = handleConnect; + } + + /** + * Returns whether or not this handler will handle CONNECT frames. + * + * @return Returns {@code true} if this handler will handle CONNECT frames, + * otherwise {@code false}. + */ + public boolean willHandleConnect() { + return this.handleConnect; + } + @Override public List getSupportedProtocols() { return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"); @@ -121,17 +145,17 @@ public class StompProtocolHandler implements SubProtocolHandler { message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); - if (SimpMessageType.CONNECT.equals(headers.getMessageType())) { + if (this.handleConnect && SimpMessageType.CONNECT.equals(headers.getMessageType())) { handleConnect(session, message); } - - outputChannel.send(message); + else { + outputChannel.send(message); + } } catch (Throwable t) { logger.error("Terminating STOMP session due to failure to send message: ", t); sendErrorMessage(session, t); } - } /** @@ -144,8 +168,8 @@ public class StompProtocolHandler implements SubProtocolHandler { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); headers.setCommandIfNotSet(StompCommand.MESSAGE); - if (StompCommand.CONNECTED.equals(headers.getCommand())) { - // Ignore for now since we already sent it + if (this.handleConnect && StompCommand.CONNECTED.equals(headers.getCommand())) { + // Ignore since we already sent it return; } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java index 7531e8ed208..f4c190ba39d 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java @@ -53,7 +53,7 @@ public class ServletStompEndpointRegistryTests { this.webSocketHandler = new SubProtocolWebSocketHandler(channel); this.queueSuffixResolver = new SimpleUserQueueSuffixResolver(); TaskScheduler taskScheduler = Mockito.mock(TaskScheduler.class); - this.registry = new ServletStompEndpointRegistry(webSocketHandler, queueSuffixResolver, taskScheduler); + this.registry = new ServletStompEndpointRegistry(webSocketHandler, queueSuffixResolver, taskScheduler, false); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index 078b77c1008..7441536ef9b 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -30,7 +30,6 @@ import org.apache.commons.logging.LogFactory; import org.junit.After; import org.junit.Before; import org.junit.Test; - import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; import org.springframework.messaging.Message; @@ -63,16 +62,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { private ExpectationMatchingEventPublisher eventPublisher; + private int port; + @Before public void setUp() throws Exception { - int port = SocketUtils.findAvailableTcpPort(61613); + this.port = SocketUtils.findAvailableTcpPort(61613); - this.activeMQBroker = new BrokerService(); - this.activeMQBroker.addConnector("stomp://localhost:" + port); - this.activeMQBroker.setStartAsync(false); - this.activeMQBroker.setDeleteAllMessagesOnStartup(true); - this.activeMQBroker.start(); + createAndStartBroker(); this.responseChannel = new ExecutorSubscribableChannel(); this.responseHandler = new ExpectationMatchingMessageHandler(); @@ -86,6 +83,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { this.relay.start(); } + private void createAndStartBroker() throws Exception { + this.activeMQBroker = new BrokerService(); + this.activeMQBroker.addConnector("stomp://localhost:" + port); + this.activeMQBroker.setStartAsync(false); + this.activeMQBroker.setDeleteAllMessagesOnStartup(true); + this.activeMQBroker.start(); + } + @After public void tearDown() throws Exception { try { @@ -102,22 +107,24 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { String sess1 = "sess1"; MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build(); this.relay.handleMessage(conn1.message); + this.responseHandler.expect(conn1); String sess2 = "sess2"; MessageExchange conn2 = MessageExchangeBuilder.connect(sess2).build(); this.relay.handleMessage(conn2.message); + this.responseHandler.expect(conn2); + + this.responseHandler.awaitAndAssert(); String subs1 = "subs1"; String destination = "/topic/test"; MessageExchange subscribe = MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build(); - this.responseHandler.expect(subscribe); - this.relay.handleMessage(subscribe.message); + this.responseHandler.expect(subscribe); this.responseHandler.awaitAndAssert(); MessageExchange send = MessageExchangeBuilder.send(destination, "foo").andExpectMessage(sess1, subs1).build(); - this.responseHandler.reset(); this.responseHandler.expect(send); this.relay.handleMessage(send.message); @@ -129,7 +136,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { stopBrokerAndAwait(); - MessageExchange connect = MessageExchangeBuilder.connect("sess1").andExpectError().build(); + MessageExchange connect = MessageExchangeBuilder.connectWithError("sess1").build(); this.responseHandler.expect(connect); this.relay.handleMessage(connect.message); @@ -137,37 +144,31 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } @Test - public void brokerUnvailableErrorFrameOnSend() throws Exception { + public void brokerBecomingUnvailableTriggersErrorFrame() throws Exception { String sess1 = "sess1"; MessageExchange connect = MessageExchangeBuilder.connect(sess1).build(); + this.responseHandler.expect(connect); + this.relay.handleMessage(connect.message); - // TODO: expect CONNECTED - Thread.sleep(2000); + this.responseHandler.awaitAndAssert(); - stopBrokerAndAwait(); + this.responseHandler.expect(MessageExchangeBuilder.error(sess1).build()); - MessageExchange subscribe = MessageExchangeBuilder.subscribe(sess1, "s1", "/topic/a").andExpectError().build(); - this.responseHandler.expect(subscribe); + stopBrokerAndAwait(); - this.relay.handleMessage(subscribe.message); this.responseHandler.awaitAndAssert(); } @Test public void brokerAvailabilityEvents() throws Exception { - // TODO: expect CONNECTED - Thread.sleep(2000); - - this.eventPublisher.expect(true, false); + this.eventPublisher.expect(true); + this.eventPublisher.awaitAndAssert(); + this.eventPublisher.expect(false); stopBrokerAndAwait(); - - // TODO: remove when stop is detecteded - this.relay.handleMessage(MessageExchangeBuilder.connect("sess1").build().message); - this.eventPublisher.awaitAndAssert(); } @@ -176,37 +177,55 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { String sess1 = "sess1"; MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build(); + this.responseHandler.expect(conn1); this.relay.handleMessage(conn1.message); + this.responseHandler.awaitAndAssert(); String subs1 = "subs1"; String destination = "/topic/test"; - MessageExchange subscribe = MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build(); + MessageExchange subscribe = + MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build(); this.responseHandler.expect(subscribe); this.relay.handleMessage(subscribe.message); this.responseHandler.awaitAndAssert(); + this.responseHandler.expect(MessageExchangeBuilder.error(sess1).build()); + stopBrokerAndAwait(); - // 1st message will see ERROR frame (broker shutdown is not but should be detected) - // 2nd message will be queued (a side effect of CONNECT/CONNECTED-buffering, likely to be removed) - // Finish this once the above changes are made. + this.responseHandler.awaitAndAssert(); -/* MessageExchange send = MessageExchangeBuilder.send(destination, "foo").build(); - this.responseHandler.reset(); - this.relay.handleMessage(send.message); - Thread.sleep(2000); + this.eventPublisher.expect(true, false); + this.eventPublisher.awaitAndAssert(); - this.activeMQBroker.start(); - Thread.sleep(5000); + this.eventPublisher.expect(true); + createAndStartBroker(); + this.eventPublisher.awaitAndAssert(); - send = MessageExchangeBuilder.send(destination, "foo").andExpectMessage(sess1, subs1).build(); - this.responseHandler.reset(); - this.responseHandler.expect(send); - this.relay.handleMessage(send.message); + // TODO The event publisher assertions show that the broker's back up and the system relay session + // has reconnected. We need to decide what we want the reconnect behaviour to be for client relay + // sessions and add further message sending and assertions as appropriate. At the moment any client + // sessions will be closed and an ERROR from will be sent. + } + @Test + public void disconnectClosesRelaySessionCleanly() throws Exception { + String sess1 = "sess1"; + MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build(); + this.responseHandler.expect(conn1); + this.relay.handleMessage(conn1.message); + this.responseHandler.awaitAndAssert(); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.setSessionId(sess1); + + this.relay.handleMessage(MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build()); + + Thread.sleep(2000); + + // Check that we have not received an ERROR as a result of the connection closing this.responseHandler.awaitAndAssert(); -*/ } @@ -234,58 +253,66 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { */ private static class ExpectationMatchingMessageHandler implements MessageHandler { - private final List expected; + private final Object monitor = new Object(); - private final List actual = new CopyOnWriteArrayList<>(); + private final List expected; - private final List> unexpected = new CopyOnWriteArrayList<>(); + private final List actual = new ArrayList<>(); - private CountDownLatch latch = new CountDownLatch(1); + private final List> unexpected = new ArrayList<>(); public ExpectationMatchingMessageHandler(MessageExchange... expected) { - this.expected = new CopyOnWriteArrayList<>(expected); + synchronized (this.monitor) { + this.expected = new CopyOnWriteArrayList<>(expected); + } } - public void expect(MessageExchange... expected) { - this.expected.addAll(Arrays.asList(expected)); + synchronized (this.monitor) { + this.expected.addAll(Arrays.asList(expected)); + } } public void awaitAndAssert() throws InterruptedException { - boolean result = this.latch.await(10000, TimeUnit.MILLISECONDS); - assertTrue(getAsString(), result && this.unexpected.isEmpty()); - } - - public void reset() { - this.latch = new CountDownLatch(1); - this.expected.clear(); - this.actual.clear(); - this.unexpected.clear(); + long endTime = System.currentTimeMillis() + 10000; + synchronized (this.monitor) { + while (!this.expected.isEmpty() && System.currentTimeMillis() < endTime) { + this.monitor.wait(500); + } + boolean result = this.expected.isEmpty(); + assertTrue(getAsString(), result && this.unexpected.isEmpty()); + } } @Override public void handleMessage(Message message) throws MessagingException { - for (MessageExchange exch : this.expected) { - if (exch.matchMessage(message)) { - if (exch.isDone()) { - this.expected.remove(exch); - this.actual.add(exch); - if (this.expected.isEmpty()) { - this.latch.countDown(); + synchronized(this.monitor) { + for (MessageExchange exch : this.expected) { + if (exch.matchMessage(message)) { + if (exch.isDone()) { + this.expected.remove(exch); + this.actual.add(exch); + if (this.expected.isEmpty()) { + this.monitor.notifyAll(); + } } + return; } - return; } + this.unexpected.add(message); } - this.unexpected.add(message); } public String getAsString() { StringBuilder sb = new StringBuilder("\n"); - sb.append("INCOMPLETE:\n").append(this.expected).append("\n"); - sb.append("COMPLETE:\n").append(this.actual).append("\n"); - sb.append("UNMATCHED MESSAGES:\n").append(this.unexpected).append("\n"); + + synchronized (this.monitor) { + sb.append("INCOMPLETE:\n").append(this.expected).append("\n"); + sb.append("COMPLETE:\n").append(this.actual).append("\n"); + sb.append("UNMATCHED MESSAGES:\n").append(this.unexpected).append("\n"); + } + return sb.toString(); } } @@ -352,22 +379,28 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { this.headers = StompHeaderAccessor.wrap(message); } + public static MessageExchangeBuilder error(String sessionId) { + return new MessageExchangeBuilder(null).andExpectError(sessionId); + } public static MessageExchangeBuilder connect(String sessionId) { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); headers.setSessionId(sessionId); headers.setAcceptVersion("1.1,1.2"); Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); - return new MessageExchangeBuilder(message); + + MessageExchangeBuilder builder = new MessageExchangeBuilder(message); + builder.expected.add(new StompConnectedFrameMessageMatcher(sessionId)); + return builder; } - public static MessageExchangeBuilder subscribe(String sessionId, String subscriptionId, String destination) { - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); + public static MessageExchangeBuilder connectWithError(String sessionId) { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); headers.setSessionId(sessionId); - headers.setSubscriptionId(subscriptionId); - headers.setDestination(destination); + headers.setAcceptVersion("1.1,1.2"); Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); - return new MessageExchangeBuilder(message); + MessageExchangeBuilder builder = new MessageExchangeBuilder(message); + return builder.andExpectError(); } public static MessageExchangeBuilder subscribeWithReceipt(String sessionId, String subscriptionId, @@ -515,35 +548,48 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } } + private static class StompConnectedFrameMessageMatcher extends StompFrameMessageMatcher { + + + public StompConnectedFrameMessageMatcher(String sessionId) { + super(StompCommand.CONNECTED, sessionId); + } + + } + private static class ExpectationMatchingEventPublisher implements ApplicationEventPublisher { - private final List expected = new CopyOnWriteArrayList<>(); + private final List expected = new ArrayList<>(); - private final List actual = new CopyOnWriteArrayList<>(); + private final List actual = new ArrayList<>(); - private CountDownLatch latch = new CountDownLatch(1); + private final Object monitor = new Object(); public void expect(Boolean... expected) { - this.expected.addAll(Arrays.asList(expected)); + synchronized (this.monitor) { + this.expected.addAll(Arrays.asList(expected)); + } } public void awaitAndAssert() throws InterruptedException { - if (this.expected.size() == this.actual.size()) { + synchronized(this.monitor) { + long endTime = System.currentTimeMillis() + 5000; + while (this.expected.size() != this.actual.size() && System.currentTimeMillis() < endTime) { + this.monitor.wait(500); + } assertEquals(this.expected, this.actual); } - else { - assertTrue("Expected=" + this.expected + ", actual=" + this.actual, - this.latch.await(5, TimeUnit.SECONDS)); - } } @Override public void publishEvent(ApplicationEvent event) { if (event instanceof BrokerAvailabilityEvent) { - this.actual.add(((BrokerAvailabilityEvent) event).isBrokerAvailable()); - if (this.actual.size() == this.expected.size()) { - this.latch.countDown(); + synchronized(this.monitor) { + this.actual.add(((BrokerAvailabilityEvent) event).isBrokerAvailable()); + if (this.actual.size() == this.expected.size()) { + this.monitor.notifyAll(); + } } } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java index a6e2afa842b..78311b36d87 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java @@ -61,25 +61,15 @@ public class StompProtocolHandlerTests { } @Test - public void handleConnect() { + public void connectedResponseIsSentWhenHandlingConnect() { + this.stompHandler.setHandleConnect(true); TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers( "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build(); this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel); - verify(this.channel).send(this.messageCaptor.capture()); - Message actual = this.messageCaptor.getValue(); - assertNotNull(actual); - - StompHeaderAccessor headers = StompHeaderAccessor.wrap(actual); - assertEquals(StompCommand.CONNECT, headers.getCommand()); - assertEquals("s1", headers.getSessionId()); - assertEquals("joe", headers.getUser().getName()); - assertEquals("guest", headers.getLogin()); - assertEquals("PROTECTED", headers.getPasscode()); - assertArrayEquals(new long[] {10000, 10000}, headers.getHeartbeat()); - assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), headers.getAcceptVersion()); + verifyNoMoreInteractions(this.channel); // Check CONNECTED reply @@ -95,4 +85,29 @@ public class StompProtocolHandlerTests { assertEquals("s1", replyHeaders.getNativeHeader("queue-suffix").get(0)); } + @Test + public void connectIsForwardedWhenNotHandlingConnect() { + this.stompHandler.setHandleConnect(false); + + TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers( + "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build(); + + this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel); + + verify(this.channel).send(this.messageCaptor.capture()); + Message actual = this.messageCaptor.getValue(); + assertNotNull(actual); + + StompHeaderAccessor headers = StompHeaderAccessor.wrap(actual); + assertEquals(StompCommand.CONNECT, headers.getCommand()); + assertEquals("s1", headers.getSessionId()); + assertEquals("joe", headers.getUser().getName()); + assertEquals("guest", headers.getLogin()); + assertEquals("PROTECTED", headers.getPasscode()); + assertArrayEquals(new long[] {10000, 10000}, headers.getHeartbeat()); + assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), headers.getAcceptVersion()); + + assertEquals(0, this.session.getSentMessages().size()); + } + }