From 496d8321c3d8618b1f42923ec9caf0403422111e Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Mon, 2 Sep 2013 15:31:08 +0100 Subject: [PATCH] Add heart-beat support to STOMP broker relay Previously, the STOMP broker relay did not support heart-beats. It sent 0,0 in the heart-beats header for its own CONNECTED message, and set the heart-beats header to 0,0 when it was forwarding a CONNECTED from from a client to the broker. The broker relay now supports heart-beats for the system relay session. It will send heart-beats at the send interval that's been negotiated with the broker and will also expect to receive heart-beats at the receive interval that's been negotiated with the broker. The receive interval is multiplied by a factor of three to satisfy the STOMP spec's suggestion of lenience and ActiveMQ 5.8.0's heart-beat behaviour (see AMQ-4710). The broker relay also supports heart-beats between clients and the broker. For any given client's relay session, any heart-beats received from the client are forwarded on to the broker and any heart-beats received from the broker are sent back to the client. Internally, a heart-beat is represented as a Message with a byte array payload containing the single byte of new line ('\n') character and 'empty' headers. SubscriptionMethodReturnValueHandler has been updated to default the message type to SimpMessageType.MESSAGE. This eases the distinction between a heartbeat and a message that's been created from a return value from application code. --- build.gradle | 5 +- .../SubscriptionMethodReturnValueHandler.java | 2 + .../stomp/StompBrokerRelayMessageHandler.java | 63 +++++++++++++++---- .../messaging/simp/stomp/StompDecoder.java | 20 +++++- .../messaging/simp/stomp/StompEncoder.java | 24 +++++-- .../simp/stomp/StompHeaderAccessor.java | 6 +- .../simp/stomp/StompProtocolHandler.java | 32 +++++++--- ...erRelayMessageHandlerIntegrationTests.java | 22 ++++--- 8 files changed, 131 insertions(+), 43 deletions(-) diff --git a/build.gradle b/build.gradle index 90bf8b5ee53..381e9895cd7 100644 --- a/build.gradle +++ b/build.gradle @@ -71,6 +71,7 @@ configure(allprojects) { project -> maven { url "https://repository.apache.org/content/repositories/releases" } // tomcat 8 RC3 maven { url "https://repository.apache.org/content/repositories/snapshots" } // tomcat-websocket-* snapshots maven { url "https://maven.java.net/content/repositories/releases" } // javax.websocket, tyrus + maven { url 'http://repo.springsource.org/libs-snapshot' } // reactor } dependencies { @@ -352,8 +353,8 @@ project("spring-messaging") { optional(project(":spring-websocket")) optional(project(":spring-webmvc")) optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") - optional("org.projectreactor:reactor-core:1.0.0.M2") - optional("org.projectreactor:reactor-tcp:1.0.0.M2") + optional("org.projectreactor:reactor-core:1.0.0.BUILD-SNAPSHOT") + optional("org.projectreactor:reactor-tcp:1.0.0.BUILD-SNAPSHOT") optional("com.lmax:disruptor:3.1.1") optional("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815") optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815") diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java index 683e493e4ab..9b0e4fdde26 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java @@ -23,6 +23,7 @@ import org.springframework.messaging.core.MessageSendingOperations; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.method.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.annotation.SubscribeEvent; import org.springframework.messaging.support.MessageBuilder; @@ -97,6 +98,7 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); headers.setSessionId(this.sessionId); headers.setSubscriptionId(this.subscriptionId); + headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); return MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 5d93031ba0a..4596be012bf 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -196,21 +196,17 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); } - if (headers.getCommand() == null) { - logger.error("No STOMP command, ignoring message: " + message); - return; - } if (sessionId == null) { logger.error("No sessionId, ignoring message: " + message); return; } - if (command.requiresDestination() && !checkDestinationPrefix(destination)) { + + if (command != null && command.requiresDestination() && !checkDestinationPrefix(destination)) { return; } try { if (SimpMessageType.CONNECT.equals(messageType)) { - headers.setHeartbeat(0, 0); message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); RelaySession session = new RelaySession(sessionId); this.relaySessions.put(sessionId, session); @@ -305,8 +301,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); if (StompCommand.CONNECTED == headers.getCommand()) { - this.stompConnection.setReady(); - publishBrokerAvailableEvent(); + connected(headers, this.stompConnection); } headers.setSessionId(this.sessionId); @@ -314,12 +309,21 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler sendMessageToClient(message); } + protected void connected(StompHeaderAccessor headers, StompConnection stompConnection) { + this.stompConnection.setReady(); + publishBrokerAvailableEvent(); + } + private void handleTcpClientFailure(String message, Throwable ex) { if (logger.isErrorEnabled()) { logger.error(message + ", sessionId=" + this.sessionId, ex); } + disconnected(message); + } + + protected void disconnected(String errorMessage) { this.stompConnection.setDisconnected(); - sendError(message); + sendError(errorMessage); publishBrokerUnavailableEvent(); } @@ -445,12 +449,16 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private class SystemRelaySession extends RelaySession { - private static final long HEARTBEAT_SEND_INTERVAL = 10000; + private static final long HEARTBEAT_RECEIVE_MULTIPLIER = 3; + + private static final long HEARTBEAT_SEND_INTERVAL = 10000; - private static final long HEARTBEAT_RECEIVE_INTERVAL = 10000; + private static final long HEARTBEAT_RECEIVE_INTERVAL = 10000; public static final String ID = "stompRelaySystemSessionId"; + private final byte[] heartbeatPayload = new byte[] {'\n'}; + public SystemRelaySession() { super(ID); @@ -481,6 +489,39 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler publishBrokerUnavailableEvent(); } + @Override + protected void connected(StompHeaderAccessor headers, final StompConnection stompConnection) { + long brokerReceiveInterval = headers.getHeartbeat()[1]; + + if (HEARTBEAT_SEND_INTERVAL > 0 && brokerReceiveInterval > 0) { + long interval = Math.max(HEARTBEAT_SEND_INTERVAL, brokerReceiveInterval); + stompConnection.connection.on().writeIdle(interval, new Runnable() { + + @Override + public void run() { + stompConnection.connection.send(MessageBuilder.withPayload(heartbeatPayload).build()); + } + + }); + } + + long brokerSendInterval = headers.getHeartbeat()[0]; + if (HEARTBEAT_RECEIVE_INTERVAL > 0 && brokerSendInterval > 0) { + final long interval = + Math.max(HEARTBEAT_RECEIVE_INTERVAL, brokerSendInterval) * HEARTBEAT_RECEIVE_MULTIPLIER; + stompConnection.connection.on().readIdle(interval, new Runnable() { + @Override + public void run() { + String message = "Broker hearbeat missed: connection idle for more than " + interval + "ms"; + logger.warn(message); + disconnected(message); + } + }); + } + + super.connected(headers, stompConnection); + } + @Override protected void sendMessageToClient(Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java index 5c45244a204..e876df99ad2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java @@ -20,6 +20,8 @@ import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; import java.nio.charset.Charset; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.LinkedMultiValueMap; @@ -35,6 +37,10 @@ public class StompDecoder { private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + private static final byte[] HEARTBEAT_PAYLOAD = new byte[] {'\n'}; + + private final Log logger = LogFactory.getLog(StompDecoder.class); + /** * Decodes a STOMP frame in the given {@code buffer} into a {@link Message}. @@ -49,12 +55,20 @@ public class StompDecoder { MultiValueMap headers = readHeaders(buffer); byte[] payload = readPayload(buffer, headers); - return MessageBuilder.withPayloadAndHeaders(payload, + Message decodedMessage = MessageBuilder.withPayloadAndHeaders(payload, StompHeaderAccessor.create(StompCommand.valueOf(command), headers)).build(); + + if (logger.isTraceEnabled()) { + logger.trace("Decoded " + decodedMessage); + } + + return decodedMessage; } else { - // Heartbeat - return null; + if (logger.isTraceEnabled()) { + logger.trace("Decoded heartbeat"); + } + return MessageBuilder.withPayload(HEARTBEAT_PAYLOAD).build(); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java index f7760d57def..aa342c2e0f8 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java @@ -23,6 +23,8 @@ import java.nio.charset.Charset; import java.util.List; import java.util.Map.Entry; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; /** @@ -39,6 +41,7 @@ public final class StompEncoder { private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + private final Log logger = LogFactory.getLog(StompEncoder.class); /** * Encodes the given STOMP {@code message} into a {@code byte[]} @@ -49,16 +52,23 @@ public final class StompEncoder { */ public byte[] encode(Message message) { try { + if (logger.isTraceEnabled()) { + logger.trace("Encoding " + message); + } ByteArrayOutputStream baos = new ByteArrayOutputStream(); DataOutputStream output = new DataOutputStream(baos); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - writeCommand(headers, output); - writeHeaders(headers, message, output); - output.write(LF); - writeBody(message, output); - output.write((byte)0); + if (isHeartbeat(headers)) { + output.write(message.getPayload()); + } else { + writeCommand(headers, output); + writeHeaders(headers, message, output); + output.write(LF); + writeBody(message, output); + output.write((byte)0); + } return baos.toByteArray(); } @@ -67,6 +77,10 @@ public final class StompEncoder { } } + private boolean isHeartbeat(StompHeaderAccessor headers) { + return headers.getCommand() == null; + } + private void writeCommand(StompHeaderAccessor headers, DataOutputStream output) throws IOException { output.write(headers.getCommand().toString().getBytes(UTF8_CHARSET)); output.write(LF); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java index f39c43fe664..f3771059446 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java @@ -82,6 +82,8 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { public static final String STOMP_HEARTBEAT_HEADER = "heart-beat"; + private static final long[] DEFAULT_HEARTBEAT = new long[] {0, 0}; + // Other header names @@ -185,7 +187,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { result.put(STOMP_CONTENT_TYPE_HEADER, Arrays.asList(contentType.toString())); } - if (getCommand().requiresSubscriptionId()) { + if (getCommand() != null && getCommand().requiresSubscriptionId()) { String subscriptionId = getSubscriptionId(); if (subscriptionId != null) { String name = StompCommand.MESSAGE.equals(getCommand()) ? STOMP_SUBSCRIPTION_HEADER : STOMP_ID_HEADER; @@ -252,7 +254,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { public long[] getHeartbeat() { String rawValue = getFirstNativeHeader(STOMP_HEARTBEAT_HEADER); if (!StringUtils.hasText(rawValue)) { - return null; + return Arrays.copyOf(DEFAULT_HEARTBEAT, 2); } String[] rawValues = StringUtils.commaDelimitedListToStringArray(rawValue); return new long[] { Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])}; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java index 1c70bbb669a..6bbd655e91b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java @@ -166,11 +166,17 @@ public class StompProtocolHandler implements SubProtocolHandler { public void handleMessageToClient(WebSocketSession session, Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - headers.setCommandIfNotSet(StompCommand.MESSAGE); + if (headers.getCommand() == null && SimpMessageType.MESSAGE == headers.getMessageType()) { + headers.setCommandIfNotSet(StompCommand.MESSAGE); + } - if (this.handleConnect && StompCommand.CONNECTED.equals(headers.getCommand())) { - // Ignore since we already sent it - return; + if (headers.getCommand() == StompCommand.CONNECTED) { + if (this.handleConnect) { + // Ignore since we already sent it + return; + } else { + augmentConnectedHeaders(headers, session); + } } if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) { @@ -222,20 +228,26 @@ public class StompProtocolHandler implements SubProtocolHandler { } connectedHeaders.setHeartbeat(0,0); + augmentConnectedHeaders(connectedHeaders, session); + + // TODO: security + + Message connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build(); + String payload = new String(this.stompEncoder.encode(connectedMessage), Charset.forName("UTF-8")); + session.sendMessage(new TextMessage(payload)); + } + + private void augmentConnectedHeaders(StompHeaderAccessor headers, WebSocketSession session) { Principal principal = session.getPrincipal(); if (principal != null) { - connectedHeaders.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); - connectedHeaders.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId()); + headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); + headers.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId()); if (this.queueSuffixResolver != null) { String suffix = session.getId(); this.queueSuffixResolver.addQueueSuffix(principal.getName(), session.getId(), suffix); } } - - Message connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build(); - String payload = new String(this.stompEncoder.encode(connectedMessage), Charset.forName("UTF-8")); - session.sendMessage(new TextMessage(payload)); } protected void sendErrorMessage(WebSocketSession session, Throwable error) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index 7441536ef9b..3cf1544313c 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -287,20 +287,22 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { @Override public void handleMessage(Message message) throws MessagingException { - synchronized(this.monitor) { - for (MessageExchange exch : this.expected) { - if (exch.matchMessage(message)) { - if (exch.isDone()) { - this.expected.remove(exch); - this.actual.add(exch); - if (this.expected.isEmpty()) { - this.monitor.notifyAll(); + if (StompHeaderAccessor.wrap(message).getCommand() != null) { + synchronized(this.monitor) { + for (MessageExchange exch : this.expected) { + if (exch.matchMessage(message)) { + if (exch.isDone()) { + this.expected.remove(exch); + this.actual.add(exch); + if (this.expected.isEmpty()) { + this.monitor.notifyAll(); + } } + return; } - return; } + this.unexpected.add(message); } - this.unexpected.add(message); } }