From be6dbe54a358744ee896d8d295a5fc0992d9e2e4 Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Thu, 1 Aug 2013 14:17:55 +0100 Subject: [PATCH] Integration tests for the broker relay --- .../stomp/StompBrokerRelayMessageHandler.java | 37 +++- ...erRelayMessageHandlerIntegrationTests.java | 186 ++++++++++++++++++ .../messaging/simp/stomp/TestStompBroker.java | 161 +++++++++++++++ 3 files changed, 378 insertions(+), 6 deletions(-) create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 3fda85e5d87..cb4b498063d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -37,7 +37,9 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import reactor.core.Environment; +import reactor.core.composable.Deferred; import reactor.core.composable.Promise; +import reactor.core.composable.spec.DeferredPromiseSpec; import reactor.function.Consumer; import reactor.tcp.TcpClient; import reactor.tcp.TcpConnection; @@ -454,19 +456,42 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId()); } byte[] bytes = stompMessageConverter.fromMessage(message); + + final Deferred> deferred = new DeferredPromiseSpec().get(); + connection.send(new String(bytes, Charset.forName("UTF-8")), new Consumer() { + @Override public void accept(Boolean success) { - if (!success) { - String sessionId = StompHeaderAccessor.wrap(message).getSessionId(); - relaySessions.remove(sessionId); - sendError(sessionId, "Failed to relay message to broker"); + if (!success && StompHeaderAccessor.wrap(message).getCommand() != StompCommand.DISCONNECT) { + deferred.accept(false); + } else { + deferred.accept(true); } } }); - // TODO: detect if send fails and send ERROR downstream (except on DISCONNECT) - return true; + Boolean success = null; + + try { + success = deferred.compose().await(); + + if (success == null) { + sendError(sessionId, "Timed out waiting for message to be forwarded to the broker"); + } + else if (!success) { + sendError(sessionId, "Failed to forward message to the broker"); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + sendError(sessionId, "Interrupted while forwarding message to the broker"); + } + + if (success == null) { + success = false; + } + + return success; } private void flushMessages(TcpConnection connection) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java new file mode 100644 index 00000000000..fd00093ed58 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -0,0 +1,186 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.io.IOException; +import java.util.Arrays; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; +import org.springframework.messaging.SubscribableChannel; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; +import org.springframework.util.SocketUtils; + +import static org.junit.Assert.*; + + +/** + * Integration tests for {@link StompBrokerRelayMessageHandler} + * + * @author Andy Wilkinson + */ +public class StompBrokerRelayMessageHandlerIntegrationTests { + + private final SubscribableChannel messageChannel = new ExecutorSubscribableChannel(); + + private final StompBrokerRelayMessageHandler relay = + new StompBrokerRelayMessageHandler(messageChannel, Arrays.asList("/queue/", "/topic/")); + + + @Test + public void basicPublishAndSubscribe() throws IOException, InterruptedException { + int port = SocketUtils.findAvailableTcpPort(); + + TestStompBroker stompBroker = new TestStompBroker(port); + stompBroker.start(); + + String client1SessionId = "abc123"; + String client2SessionId = "def456"; + + final CountDownLatch messageLatch = new CountDownLatch(1); + + messageChannel.subscribe(new MessageHandler() { + + @Override + public void handleMessage(Message message) throws MessagingException { + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + if (headers.getCommand() == StompCommand.MESSAGE) { + messageLatch.countDown(); + } + } + + }); + + relay.setRelayPort(port); + relay.start(); + + relay.handleMessage(createConnectMessage(client1SessionId)); + relay.handleMessage(createConnectMessage(client2SessionId)); + relay.handleMessage(createSubscribeMessage(client1SessionId, "/topic/test")); + + stompBroker.awaitMessages(4); + + relay.handleMessage(createSendMessage(client2SessionId, "/topic/test", "fromClient2")); + + assertTrue(messageLatch.await(30, TimeUnit.SECONDS)); + + this.relay.stop(); + stompBroker.stop(); + } + + @Test + public void whenConnectFailsDueToTheBrokerBeingUnavailableAnErrorFrameIsSentToTheClient() + throws IOException, InterruptedException { + int port = SocketUtils.findAvailableTcpPort(); + + TestStompBroker stompBroker = new TestStompBroker(port); + stompBroker.start(); + + String sessionId = "abc123"; + + final CountDownLatch errorLatch = new CountDownLatch(1); + + messageChannel.subscribe(new MessageHandler() { + + @Override + public void handleMessage(Message message) throws MessagingException { + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + if (headers.getCommand() == StompCommand.ERROR) { + errorLatch.countDown(); + } + } + + }); + + relay.setRelayPort(port); + relay.start(); + + stompBroker.awaitMessages(1); + + stompBroker.stop(); + + relay.handleMessage(createConnectMessage(sessionId)); + + errorLatch.await(30, TimeUnit.SECONDS); + } + + @Test + public void whenSendFailsDueToTheBrokerBeingUnavailableAnErrorFrameIsSentToTheClient() + throws IOException, InterruptedException { + int port = SocketUtils.findAvailableTcpPort(); + + TestStompBroker stompBroker = new TestStompBroker(port); + stompBroker.start(); + + String sessionId = "abc123"; + + final CountDownLatch errorLatch = new CountDownLatch(1); + + messageChannel.subscribe(new MessageHandler() { + + @Override + public void handleMessage(Message message) throws MessagingException { + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + if (headers.getCommand() == StompCommand.ERROR) { + errorLatch.countDown(); + } + } + + }); + + relay.setRelayPort(port); + relay.start(); + + relay.handleMessage(createConnectMessage(sessionId)); + + stompBroker.awaitMessages(2); + + stompBroker.stop(); + + relay.handleMessage(createSubscribeMessage(sessionId, "/topic/test/")); + + errorLatch.await(30, TimeUnit.SECONDS); + } + + private Message createConnectMessage(String sessionId) { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setSessionId(sessionId); + return MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + } + + private Message createSubscribeMessage(String sessionId, String destination) { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); + headers.setSessionId(sessionId); + headers.setDestination(destination); + headers.setNativeHeader(StompHeaderAccessor.STOMP_ID_HEADER, sessionId); + + return MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + } + + private Message createSendMessage(String sessionId, String destination, String payload) { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.setSessionId(sessionId); + headers.setDestination(destination); + + return MessageBuilder.withPayloadAndHeaders(payload.getBytes(), headers).build(); + } +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java new file mode 100644 index 00000000000..fe5d7462dd5 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java @@ -0,0 +1,161 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; + +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; + +import reactor.core.Environment; +import reactor.function.Consumer; +import reactor.tcp.TcpConnection; +import reactor.tcp.TcpServer; +import reactor.tcp.encoding.DelimitedCodec; +import reactor.tcp.encoding.StandardCodecs; +import reactor.tcp.netty.NettyTcpServer; +import reactor.tcp.spec.TcpServerSpec; + +/** + * @author Andy Wilkinson + */ +class TestStompBroker { + + private final StompMessageConverter messageConverter = new StompMessageConverter(); + + private final List> messages = new ArrayList>(); + + private final Object messageMonitor = new Object(); + + private final Object subscriberMonitor = new Object(); + + private final Map> subscribers = new HashMap>(); + + private final AtomicLong messageIdCounter = new AtomicLong(); + + private final int port; + + private volatile Environment environment; + + private volatile TcpServer tcpServer; + + TestStompBroker(int port) { + this.port = port; + } + + public void start() throws IOException { + this.environment = new Environment(); + + this.tcpServer = new TcpServerSpec(NettyTcpServer.class) + .env(this.environment) + .codec(new DelimitedCodec((byte) 0, true, StandardCodecs.STRING_CODEC)) + .listen(port) + .consume(new Consumer>() { + + @Override + public void accept(final TcpConnection connection) { + connection.consume(new Consumer() { + @Override + public void accept(String stompFrame) { + handleMessage(messageConverter.toMessage(stompFrame), connection); + } + }); + } + }) + .get(); + + this.tcpServer.start(); + } + + public void stop() throws IOException, InterruptedException { + this.tcpServer.shutdown().await(); + } + + private void handleMessage(Message message, TcpConnection connection) { + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + if (headers.getCommand() == StompCommand.CONNECT) { + StompHeaderAccessor responseHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED); + MessageBuilder response = MessageBuilder.withPayloadAndHeaders(new byte[0], responseHeaders); + connection.send(new String(messageConverter.fromMessage(response.build()))); + } + else if (headers.getCommand() == StompCommand.SUBSCRIBE) { + String destination = headers.getDestination(); + synchronized (this.subscriberMonitor) { + Set subscribers = this.subscribers.get(destination); + if (subscribers == null) { + subscribers = new HashSet(); + this.subscribers.put(destination, subscribers); + } + String subscriptionId = headers.getFirstNativeHeader(StompHeaderAccessor.STOMP_ID_HEADER); + subscribers.add(new Subscription(subscriptionId, connection)); + } + } + else if (headers.getCommand() == StompCommand.SEND) { + String destination = headers.getDestination(); + synchronized (this.subscriberMonitor) { + Set subscriptions = this.subscribers.get(destination); + if (subscriptions != null) { + for (Subscription subscription: subscriptions) { + StompHeaderAccessor outboundHeaders = StompHeaderAccessor.create(StompCommand.MESSAGE); + outboundHeaders.setSubscriptionId(subscription.subscriptionId); + outboundHeaders.setMessageId(Long.toString(messageIdCounter.incrementAndGet())); + Message outbound = + MessageBuilder.withPayloadAndHeaders(message.getPayload(), outboundHeaders).build(); + subscription.tcpConnection.send(new String(this.messageConverter.fromMessage(outbound))); + } + } + } + } + addMessage(message); + } + + private void addMessage(Message message) { + synchronized (this.messageMonitor) { + this.messages.add(message); + this.messageMonitor.notifyAll(); + } + } + + public List> awaitMessages(int messageCount) throws InterruptedException { + synchronized (this.messageMonitor) { + while (this.messages.size() < messageCount) { + this.messageMonitor.wait(); + } + return this.messages; + } + } + + private static final class Subscription { + + private final String subscriptionId; + + private final TcpConnection tcpConnection; + + public Subscription(String subscriptionId, TcpConnection tcpConnection) { + this.subscriptionId = subscriptionId; + this.tcpConnection = tcpConnection; + } + + } +}