diff --git a/build.gradle b/build.gradle index 68f90ace0c2..1b550caa690 100644 --- a/build.gradle +++ b/build.gradle @@ -318,8 +318,8 @@ project("spring-messaging") { compile(project(":spring-context")) optional(project(":spring-websocket")) optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") - optional("org.projectreactor:reactor-core:1.0.0.M1") - optional("org.projectreactor:reactor-tcp:1.0.0.M1") + optional("org.projectreactor:reactor-core:1.0.0.BUILD-SNAPSHOT") + optional("org.projectreactor:reactor-tcp:1.0.0.BUILD-SNAPSHOT") optional("com.lmax:disruptor:3.1.1") testCompile("commons-dbcp:commons-dbcp:1.2.2") testCompile("javax.inject:javax.inject-tck:1") @@ -328,6 +328,7 @@ project("spring-messaging") { repositories { maven { url 'http://repo.springsource.org/libs-milestone' } // reactor + maven { url 'http://repo.springsource.org/libs-snapshot' } // reactor } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 55d239c9ba8..2ad569dc72a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -16,6 +16,7 @@ package org.springframework.messaging.simp.stomp; +import java.net.InetSocketAddress; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Collection; @@ -41,16 +42,20 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import reactor.core.Environment; +import reactor.core.composable.Composable; import reactor.core.composable.Deferred; import reactor.core.composable.Promise; import reactor.core.composable.spec.DeferredPromiseSpec; import reactor.function.Consumer; +import reactor.tcp.Reconnect; import reactor.tcp.TcpClient; import reactor.tcp.TcpConnection; import reactor.tcp.encoding.DelimitedCodec; import reactor.tcp.encoding.StandardCodecs; import reactor.tcp.netty.NettyTcpClient; import reactor.tcp.spec.TcpClientSpec; +import reactor.tuple.Tuple; +import reactor.tuple.Tuple2; /** @@ -219,12 +224,24 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife private void openSystemSession() { RelaySession session = new RelaySession(STOMP_RELAY_SYSTEM_SESSION_ID) { + @Override protected void sendMessageToClient(Message message) { // ignore, only used to send messages // TODO: ERROR frame/reconnect } + + @Override + protected Composable> openConnection() { + return tcpClient.open(new Reconnect() { + @Override + public Tuple2 reconnect(InetSocketAddress currentAddress, int attempt) { + return Tuple.of(currentAddress, 5000L); + } + }); + } }; + this.relaySessions.put(STOMP_RELAY_SYSTEM_SESSION_ID, session); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); @@ -376,7 +393,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife private final BlockingQueue> messageQueue = new LinkedBlockingQueue>(50); - private Promise> promise; + private volatile TcpConnection connection = null; private volatile boolean isConnected = false; @@ -391,21 +408,24 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife public void open(final Message message) { Assert.notNull(message, "message is required"); - this.promise = tcpClient.open(); + Composable> connectionComposable = openConnection(); - this.promise.consume(new Consumer>() { + connectionComposable.consume(new Consumer>() { @Override - public void accept(TcpConnection connection) { - connection.in().consume(new Consumer() { + public void accept(TcpConnection newConnection) { + isConnected = false; + connection = newConnection; + newConnection.in().consume(new Consumer() { @Override public void accept(String stompFrame) { readStompFrame(stompFrame); } }); - forwardInternal(message, connection); + forwardInternal(message); } }); - this.promise.onError(new Consumer() { + + connectionComposable.when(Throwable.class, new Consumer() { @Override public void accept(Throwable ex) { relaySessions.remove(sessionId); @@ -415,6 +435,10 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife }); } + protected Composable> openConnection() { + return tcpClient.open(); + } + private void readStompFrame(String stompFrame) { // heartbeat @@ -432,7 +456,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife synchronized(this.monitor) { this.isConnected = true; brokerAvailable(); - flushMessages(this.promise.get()); + flushMessages(); } return; } @@ -447,7 +471,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife } private void sendError(String sessionId, String errorText) { - brokerUnavailable(); + disconnect(); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); headers.setSessionId(sessionId); @@ -456,6 +480,14 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife sendMessageToClient(errorMessage); } + private void disconnect() { + this.isConnected = false; + this.connection.close(); + this.connection = null; + + brokerUnavailable(); + } + public void forward(Message message) { if (!this.isConnected) { @@ -463,72 +495,77 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife if (!this.isConnected) { this.messageQueue.add(message); if (logger.isTraceEnabled()) { - logger.trace("Not connected yet, message queued, queue size=" + this.messageQueue.size()); + logger.trace("Not connected, message queued. Queue size=" + this.messageQueue.size()); } return; } } } - TcpConnection connection = this.promise.get(); - if (this.messageQueue.isEmpty()) { - forwardInternal(message, connection); + forwardInternal(message); } else { this.messageQueue.add(message); - flushMessages(connection); + flushMessages(); } } - private boolean forwardInternal(final Message message, TcpConnection connection) { - if (logger.isTraceEnabled()) { - logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId()); - } - byte[] bytes = stompMessageConverter.fromMessage(message); + private boolean forwardInternal(final Message message) { + TcpConnection localConnection = this.connection; - final Deferred> deferred = new DeferredPromiseSpec().get(); + if (localConnection != null) { + if (logger.isTraceEnabled()) { + logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId()); + } + byte[] bytes = stompMessageConverter.fromMessage(message); - connection.send(new String(bytes, Charset.forName("UTF-8")), new Consumer() { + final Deferred> deferred = new DeferredPromiseSpec().get(); - @Override - public void accept(Boolean success) { - if (!success && StompHeaderAccessor.wrap(message).getCommand() != StompCommand.DISCONNECT) { - deferred.accept(false); - } else { - deferred.accept(true); + String payload = new String(bytes, Charset.forName("UTF-8")); + localConnection.send(payload, new Consumer() { + + @Override + public void accept(Boolean success) { + if (!success && StompHeaderAccessor.wrap(message).getCommand() != StompCommand.DISCONNECT) { + deferred.accept(false); + } else { + deferred.accept(true); + } } - } - }); + }); - Boolean success = null; + Boolean success = null; - try { - success = deferred.compose().await(); + try { + success = deferred.compose().await(); - if (success == null) { - sendError(sessionId, "Timed out waiting for message to be forwarded to the broker"); + if (success == null) { + sendError(sessionId, "Timed out waiting for message to be forwarded to the broker"); + } + else if (!success) { + sendError(sessionId, "Failed to forward message to the broker"); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + sendError(sessionId, "Interrupted while forwarding message to the broker"); } - else if (!success) { - sendError(sessionId, "Failed to forward message to the broker"); + + if (success == null) { + success = false; } - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - sendError(sessionId, "Interrupted while forwarding message to the broker"); - } - if (success == null) { - success = false; + return success; + } else { + return false; } - - return success; } - private void flushMessages(TcpConnection connection) { + private void flushMessages() { List> messages = new ArrayList>(); this.messageQueue.drainTo(messages); for (Message message : messages) { - if (!forwardInternal(message, connection)) { + if (!forwardInternal(message)) { return; } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index 92eaa501bcd..de26562c40e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -43,8 +43,8 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.util.SocketUtils; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertEquals; /** @@ -81,7 +81,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { final CountDownLatch messageLatch = new CountDownLatch(1); - messageChannel.subscribe(new MessageHandler() { + this.messageChannel.subscribe(new MessageHandler() { @Override public void handleMessage(Message message) throws MessagingException { @@ -93,18 +93,18 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { }); - relay.handleMessage(createConnectMessage(client1SessionId)); - relay.handleMessage(createConnectMessage(client2SessionId)); - relay.handleMessage(createSubscribeMessage(client1SessionId, "/topic/test")); + this.relay.handleMessage(createConnectMessage(client1SessionId)); + this.relay.handleMessage(createConnectMessage(client2SessionId)); + this.relay.handleMessage(createSubscribeMessage(client1SessionId, "/topic/test")); - stompBroker.awaitMessages(4); + this.stompBroker.awaitMessages(4); - relay.handleMessage(createSendMessage(client2SessionId, "/topic/test", "fromClient2")); + this.relay.handleMessage(createSendMessage(client2SessionId, "/topic/test", "fromClient2")); assertTrue(messageLatch.await(30, TimeUnit.SECONDS)); - assertEquals(1, brokerAvailabilityListener.availabilityEvents.size()); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + List availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); } @Test @@ -115,7 +115,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { final CountDownLatch errorLatch = new CountDownLatch(1); - messageChannel.subscribe(new MessageHandler() { + this.messageChannel.subscribe(new MessageHandler() { @Override public void handleMessage(Message message) throws MessagingException { @@ -127,20 +127,20 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { }); - stompBroker.awaitMessages(1); + this.stompBroker.awaitMessages(1); - assertEquals(1, brokerAvailabilityListener.availabilityEvents.size()); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + List availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); - stompBroker.stop(); + this.stompBroker.stop(); - relay.handleMessage(createConnectMessage(sessionId)); + this.relay.handleMessage(createConnectMessage(sessionId)); errorLatch.await(30, TimeUnit.SECONDS); - assertEquals(2, brokerAvailabilityListener.availabilityEvents.size()); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); + availabilityEvents = brokerAvailabilityListener.awaitAvailabilityEvents(2); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + assertTrue(availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); } @Test @@ -151,7 +151,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { final CountDownLatch errorLatch = new CountDownLatch(1); - messageChannel.subscribe(new MessageHandler() { + this.messageChannel.subscribe(new MessageHandler() { @Override public void handleMessage(Message message) throws MessagingException { @@ -163,22 +163,51 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { }); - relay.handleMessage(createConnectMessage(sessionId)); + this.relay.handleMessage(createConnectMessage(sessionId)); - stompBroker.awaitMessages(2); + this.stompBroker.awaitMessages(2); - assertEquals(1, brokerAvailabilityListener.availabilityEvents.size()); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + List availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); - stompBroker.stop(); + this.stompBroker.stop(); - relay.handleMessage(createSubscribeMessage(sessionId, "/topic/test/")); + this.relay.handleMessage(createSubscribeMessage(sessionId, "/topic/test/")); errorLatch.await(30, TimeUnit.SECONDS); - assertEquals(2, brokerAvailabilityListener.availabilityEvents.size()); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); + availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + assertTrue(availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); + } + + @Test + public void relayReconnectsIfTheBrokerComesBackUp() throws InterruptedException { + List availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + + List> messages = this.stompBroker.awaitMessages(1); + assertEquals(1, messages.size()); + assertStompCommand(messages.get(0), StompCommand.CONNECT); + + this.stompBroker.stop(); + + this.relay.handleMessage(createSendMessage(null, "/topic/test", "test")); + + availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(2); + assertTrue(availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); + + this.relay.handleMessage(createSendMessage(null, "/topic/test", "test-again")); + + this.stompBroker.start(); + + messages = this.stompBroker.awaitMessages(3); + assertEquals(3, messages.size()); + assertStompCommand(messages.get(1), StompCommand.CONNECT); + assertStompCommandAndPayload(messages.get(2), StompCommand.SEND, "test-again"); + + availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(3); + assertTrue(availabilityEvents.get(2) instanceof BrokerBecameAvailableEvent); } private Message createConnectMessage(String sessionId) { @@ -204,6 +233,16 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { return MessageBuilder.withPayloadAndHeaders(payload.getBytes(), headers).build(); } + private void assertStompCommand(Message message, StompCommand expectedCommand) { + assertEquals(expectedCommand, StompHeaderAccessor.wrap(message).getCommand()); + } + + private void assertStompCommandAndPayload(Message message, StompCommand expectedCommand, + String expectedPayload) { + assertStompCommand(message, expectedCommand); + assertEquals(expectedPayload, new String(((byte[])message.getPayload()))); + } + @Configuration public static class TestConfiguration { @@ -233,14 +272,27 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } } - private static class BrokerAvailabilityListener implements ApplicationListener { private final List availabilityEvents = new ArrayList(); + private final Object monitor = new Object(); + @Override public void onApplicationEvent(BrokerAvailabilityEvent event) { - this.availabilityEvents.add(event); + synchronized (this.monitor) { + this.availabilityEvents.add(event); + this.monitor.notifyAll(); + } + } + + private List awaitAvailabilityEvents(int eventCount) throws InterruptedException { + synchronized (this.monitor) { + while (this.availabilityEvents.size() < eventCount) { + this.monitor.wait(); + } + return new ArrayList(this.availabilityEvents); + } } } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java index f5f648ced53..305e8937c33 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java @@ -16,7 +16,6 @@ package org.springframework.messaging.simp.stomp; -import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet;