From 131b5de6f948e9dd32bdeab4bb900712854bb8f3 Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Wed, 14 Aug 2013 11:46:09 +0100 Subject: [PATCH] Add reconnect logic to the relay's system session Upgrade to Reactor snapshot builds to take advantage of TcpClient's reconnect support that was added post-M1. Now, the system relay session will try every 5 seconds to establish a connection with the broker, both when first connecting and in the event of subsequently becoming disconnected. A more sophisticated reconnection policy, including back off and failover to different brokers, is possible with the Reactor API. We may want to enhance the relay's reconnection policy in the future. Typically, a broken connection is identified by the failure to forward a message to the broker. As things stand, the message id then discarded. Any further messages that are forwarded before the connection's been re-established are queued for forwarding once the CONNECTED frame's been received. We may want to consider also queueing the message that failed to send, however we would then need to consider the possibility of the message itself being what caused the broker to close the connection and resending it would simply cause the connection to be closed again. --- build.gradle | 5 +- .../stomp/StompBrokerRelayMessageHandler.java | 129 +++++++++++------- ...erRelayMessageHandlerIntegrationTests.java | 112 +++++++++++---- .../messaging/simp/stomp/TestStompBroker.java | 1 - 4 files changed, 168 insertions(+), 79 deletions(-) diff --git a/build.gradle b/build.gradle index 68f90ace0c2..1b550caa690 100644 --- a/build.gradle +++ b/build.gradle @@ -318,8 +318,8 @@ project("spring-messaging") { compile(project(":spring-context")) optional(project(":spring-websocket")) optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") - optional("org.projectreactor:reactor-core:1.0.0.M1") - optional("org.projectreactor:reactor-tcp:1.0.0.M1") + optional("org.projectreactor:reactor-core:1.0.0.BUILD-SNAPSHOT") + optional("org.projectreactor:reactor-tcp:1.0.0.BUILD-SNAPSHOT") optional("com.lmax:disruptor:3.1.1") testCompile("commons-dbcp:commons-dbcp:1.2.2") testCompile("javax.inject:javax.inject-tck:1") @@ -328,6 +328,7 @@ project("spring-messaging") { repositories { maven { url 'http://repo.springsource.org/libs-milestone' } // reactor + maven { url 'http://repo.springsource.org/libs-snapshot' } // reactor } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 55d239c9ba8..2ad569dc72a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -16,6 +16,7 @@ package org.springframework.messaging.simp.stomp; +import java.net.InetSocketAddress; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Collection; @@ -41,16 +42,20 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import reactor.core.Environment; +import reactor.core.composable.Composable; import reactor.core.composable.Deferred; import reactor.core.composable.Promise; import reactor.core.composable.spec.DeferredPromiseSpec; import reactor.function.Consumer; +import reactor.tcp.Reconnect; import reactor.tcp.TcpClient; import reactor.tcp.TcpConnection; import reactor.tcp.encoding.DelimitedCodec; import reactor.tcp.encoding.StandardCodecs; import reactor.tcp.netty.NettyTcpClient; import reactor.tcp.spec.TcpClientSpec; +import reactor.tuple.Tuple; +import reactor.tuple.Tuple2; /** @@ -219,12 +224,24 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife private void openSystemSession() { RelaySession session = new RelaySession(STOMP_RELAY_SYSTEM_SESSION_ID) { + @Override protected void sendMessageToClient(Message message) { // ignore, only used to send messages // TODO: ERROR frame/reconnect } + + @Override + protected Composable> openConnection() { + return tcpClient.open(new Reconnect() { + @Override + public Tuple2 reconnect(InetSocketAddress currentAddress, int attempt) { + return Tuple.of(currentAddress, 5000L); + } + }); + } }; + this.relaySessions.put(STOMP_RELAY_SYSTEM_SESSION_ID, session); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); @@ -376,7 +393,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife private final BlockingQueue> messageQueue = new LinkedBlockingQueue>(50); - private Promise> promise; + private volatile TcpConnection connection = null; private volatile boolean isConnected = false; @@ -391,21 +408,24 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife public void open(final Message message) { Assert.notNull(message, "message is required"); - this.promise = tcpClient.open(); + Composable> connectionComposable = openConnection(); - this.promise.consume(new Consumer>() { + connectionComposable.consume(new Consumer>() { @Override - public void accept(TcpConnection connection) { - connection.in().consume(new Consumer() { + public void accept(TcpConnection newConnection) { + isConnected = false; + connection = newConnection; + newConnection.in().consume(new Consumer() { @Override public void accept(String stompFrame) { readStompFrame(stompFrame); } }); - forwardInternal(message, connection); + forwardInternal(message); } }); - this.promise.onError(new Consumer() { + + connectionComposable.when(Throwable.class, new Consumer() { @Override public void accept(Throwable ex) { relaySessions.remove(sessionId); @@ -415,6 +435,10 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife }); } + protected Composable> openConnection() { + return tcpClient.open(); + } + private void readStompFrame(String stompFrame) { // heartbeat @@ -432,7 +456,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife synchronized(this.monitor) { this.isConnected = true; brokerAvailable(); - flushMessages(this.promise.get()); + flushMessages(); } return; } @@ -447,7 +471,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife } private void sendError(String sessionId, String errorText) { - brokerUnavailable(); + disconnect(); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); headers.setSessionId(sessionId); @@ -456,6 +480,14 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife sendMessageToClient(errorMessage); } + private void disconnect() { + this.isConnected = false; + this.connection.close(); + this.connection = null; + + brokerUnavailable(); + } + public void forward(Message message) { if (!this.isConnected) { @@ -463,72 +495,77 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife if (!this.isConnected) { this.messageQueue.add(message); if (logger.isTraceEnabled()) { - logger.trace("Not connected yet, message queued, queue size=" + this.messageQueue.size()); + logger.trace("Not connected, message queued. Queue size=" + this.messageQueue.size()); } return; } } } - TcpConnection connection = this.promise.get(); - if (this.messageQueue.isEmpty()) { - forwardInternal(message, connection); + forwardInternal(message); } else { this.messageQueue.add(message); - flushMessages(connection); + flushMessages(); } } - private boolean forwardInternal(final Message message, TcpConnection connection) { - if (logger.isTraceEnabled()) { - logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId()); - } - byte[] bytes = stompMessageConverter.fromMessage(message); + private boolean forwardInternal(final Message message) { + TcpConnection localConnection = this.connection; - final Deferred> deferred = new DeferredPromiseSpec().get(); + if (localConnection != null) { + if (logger.isTraceEnabled()) { + logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId()); + } + byte[] bytes = stompMessageConverter.fromMessage(message); - connection.send(new String(bytes, Charset.forName("UTF-8")), new Consumer() { + final Deferred> deferred = new DeferredPromiseSpec().get(); - @Override - public void accept(Boolean success) { - if (!success && StompHeaderAccessor.wrap(message).getCommand() != StompCommand.DISCONNECT) { - deferred.accept(false); - } else { - deferred.accept(true); + String payload = new String(bytes, Charset.forName("UTF-8")); + localConnection.send(payload, new Consumer() { + + @Override + public void accept(Boolean success) { + if (!success && StompHeaderAccessor.wrap(message).getCommand() != StompCommand.DISCONNECT) { + deferred.accept(false); + } else { + deferred.accept(true); + } } - } - }); + }); - Boolean success = null; + Boolean success = null; - try { - success = deferred.compose().await(); + try { + success = deferred.compose().await(); - if (success == null) { - sendError(sessionId, "Timed out waiting for message to be forwarded to the broker"); + if (success == null) { + sendError(sessionId, "Timed out waiting for message to be forwarded to the broker"); + } + else if (!success) { + sendError(sessionId, "Failed to forward message to the broker"); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + sendError(sessionId, "Interrupted while forwarding message to the broker"); } - else if (!success) { - sendError(sessionId, "Failed to forward message to the broker"); + + if (success == null) { + success = false; } - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - sendError(sessionId, "Interrupted while forwarding message to the broker"); - } - if (success == null) { - success = false; + return success; + } else { + return false; } - - return success; } - private void flushMessages(TcpConnection connection) { + private void flushMessages() { List> messages = new ArrayList>(); this.messageQueue.drainTo(messages); for (Message message : messages) { - if (!forwardInternal(message, connection)) { + if (!forwardInternal(message)) { return; } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index 92eaa501bcd..de26562c40e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -43,8 +43,8 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.util.SocketUtils; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertEquals; /** @@ -81,7 +81,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { final CountDownLatch messageLatch = new CountDownLatch(1); - messageChannel.subscribe(new MessageHandler() { + this.messageChannel.subscribe(new MessageHandler() { @Override public void handleMessage(Message message) throws MessagingException { @@ -93,18 +93,18 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { }); - relay.handleMessage(createConnectMessage(client1SessionId)); - relay.handleMessage(createConnectMessage(client2SessionId)); - relay.handleMessage(createSubscribeMessage(client1SessionId, "/topic/test")); + this.relay.handleMessage(createConnectMessage(client1SessionId)); + this.relay.handleMessage(createConnectMessage(client2SessionId)); + this.relay.handleMessage(createSubscribeMessage(client1SessionId, "/topic/test")); - stompBroker.awaitMessages(4); + this.stompBroker.awaitMessages(4); - relay.handleMessage(createSendMessage(client2SessionId, "/topic/test", "fromClient2")); + this.relay.handleMessage(createSendMessage(client2SessionId, "/topic/test", "fromClient2")); assertTrue(messageLatch.await(30, TimeUnit.SECONDS)); - assertEquals(1, brokerAvailabilityListener.availabilityEvents.size()); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + List availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); } @Test @@ -115,7 +115,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { final CountDownLatch errorLatch = new CountDownLatch(1); - messageChannel.subscribe(new MessageHandler() { + this.messageChannel.subscribe(new MessageHandler() { @Override public void handleMessage(Message message) throws MessagingException { @@ -127,20 +127,20 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { }); - stompBroker.awaitMessages(1); + this.stompBroker.awaitMessages(1); - assertEquals(1, brokerAvailabilityListener.availabilityEvents.size()); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + List availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); - stompBroker.stop(); + this.stompBroker.stop(); - relay.handleMessage(createConnectMessage(sessionId)); + this.relay.handleMessage(createConnectMessage(sessionId)); errorLatch.await(30, TimeUnit.SECONDS); - assertEquals(2, brokerAvailabilityListener.availabilityEvents.size()); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); + availabilityEvents = brokerAvailabilityListener.awaitAvailabilityEvents(2); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + assertTrue(availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); } @Test @@ -151,7 +151,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { final CountDownLatch errorLatch = new CountDownLatch(1); - messageChannel.subscribe(new MessageHandler() { + this.messageChannel.subscribe(new MessageHandler() { @Override public void handleMessage(Message message) throws MessagingException { @@ -163,22 +163,51 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { }); - relay.handleMessage(createConnectMessage(sessionId)); + this.relay.handleMessage(createConnectMessage(sessionId)); - stompBroker.awaitMessages(2); + this.stompBroker.awaitMessages(2); - assertEquals(1, brokerAvailabilityListener.availabilityEvents.size()); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + List availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); - stompBroker.stop(); + this.stompBroker.stop(); - relay.handleMessage(createSubscribeMessage(sessionId, "/topic/test/")); + this.relay.handleMessage(createSubscribeMessage(sessionId, "/topic/test/")); errorLatch.await(30, TimeUnit.SECONDS); - assertEquals(2, brokerAvailabilityListener.availabilityEvents.size()); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); - assertTrue(brokerAvailabilityListener.availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); + availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + assertTrue(availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); + } + + @Test + public void relayReconnectsIfTheBrokerComesBackUp() throws InterruptedException { + List availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1); + assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); + + List> messages = this.stompBroker.awaitMessages(1); + assertEquals(1, messages.size()); + assertStompCommand(messages.get(0), StompCommand.CONNECT); + + this.stompBroker.stop(); + + this.relay.handleMessage(createSendMessage(null, "/topic/test", "test")); + + availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(2); + assertTrue(availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); + + this.relay.handleMessage(createSendMessage(null, "/topic/test", "test-again")); + + this.stompBroker.start(); + + messages = this.stompBroker.awaitMessages(3); + assertEquals(3, messages.size()); + assertStompCommand(messages.get(1), StompCommand.CONNECT); + assertStompCommandAndPayload(messages.get(2), StompCommand.SEND, "test-again"); + + availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(3); + assertTrue(availabilityEvents.get(2) instanceof BrokerBecameAvailableEvent); } private Message createConnectMessage(String sessionId) { @@ -204,6 +233,16 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { return MessageBuilder.withPayloadAndHeaders(payload.getBytes(), headers).build(); } + private void assertStompCommand(Message message, StompCommand expectedCommand) { + assertEquals(expectedCommand, StompHeaderAccessor.wrap(message).getCommand()); + } + + private void assertStompCommandAndPayload(Message message, StompCommand expectedCommand, + String expectedPayload) { + assertStompCommand(message, expectedCommand); + assertEquals(expectedPayload, new String(((byte[])message.getPayload()))); + } + @Configuration public static class TestConfiguration { @@ -233,14 +272,27 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } } - private static class BrokerAvailabilityListener implements ApplicationListener { private final List availabilityEvents = new ArrayList(); + private final Object monitor = new Object(); + @Override public void onApplicationEvent(BrokerAvailabilityEvent event) { - this.availabilityEvents.add(event); + synchronized (this.monitor) { + this.availabilityEvents.add(event); + this.monitor.notifyAll(); + } + } + + private List awaitAvailabilityEvents(int eventCount) throws InterruptedException { + synchronized (this.monitor) { + while (this.availabilityEvents.size() < eventCount) { + this.monitor.wait(); + } + return new ArrayList(this.availabilityEvents); + } } } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java index f5f648ced53..305e8937c33 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/TestStompBroker.java @@ -16,7 +16,6 @@ package org.springframework.messaging.simp.stomp; -import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet;