diff --git a/build.gradle b/build.gradle index 90bf8b5ee53..381e9895cd7 100644 --- a/build.gradle +++ b/build.gradle @@ -71,6 +71,7 @@ configure(allprojects) { project -> maven { url "https://repository.apache.org/content/repositories/releases" } // tomcat 8 RC3 maven { url "https://repository.apache.org/content/repositories/snapshots" } // tomcat-websocket-* snapshots maven { url "https://maven.java.net/content/repositories/releases" } // javax.websocket, tyrus + maven { url 'http://repo.springsource.org/libs-snapshot' } // reactor } dependencies { @@ -352,8 +353,8 @@ project("spring-messaging") { optional(project(":spring-websocket")) optional(project(":spring-webmvc")) optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") - optional("org.projectreactor:reactor-core:1.0.0.M2") - optional("org.projectreactor:reactor-tcp:1.0.0.M2") + optional("org.projectreactor:reactor-core:1.0.0.BUILD-SNAPSHOT") + optional("org.projectreactor:reactor-tcp:1.0.0.BUILD-SNAPSHOT") optional("com.lmax:disruptor:3.1.1") optional("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815") optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815") diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java index 683e493e4ab..9b0e4fdde26 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java @@ -23,6 +23,7 @@ import org.springframework.messaging.core.MessageSendingOperations; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.method.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.annotation.SubscribeEvent; import org.springframework.messaging.support.MessageBuilder; @@ -97,6 +98,7 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); headers.setSessionId(this.sessionId); headers.setSubscriptionId(this.subscriptionId); + headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); return MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 5d93031ba0a..4596be012bf 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -196,21 +196,17 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); } - if (headers.getCommand() == null) { - logger.error("No STOMP command, ignoring message: " + message); - return; - } if (sessionId == null) { logger.error("No sessionId, ignoring message: " + message); return; } - if (command.requiresDestination() && !checkDestinationPrefix(destination)) { + + if (command != null && command.requiresDestination() && !checkDestinationPrefix(destination)) { return; } try { if (SimpMessageType.CONNECT.equals(messageType)) { - headers.setHeartbeat(0, 0); message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); RelaySession session = new RelaySession(sessionId); this.relaySessions.put(sessionId, session); @@ -305,8 +301,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); if (StompCommand.CONNECTED == headers.getCommand()) { - this.stompConnection.setReady(); - publishBrokerAvailableEvent(); + connected(headers, this.stompConnection); } headers.setSessionId(this.sessionId); @@ -314,12 +309,21 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler sendMessageToClient(message); } + protected void connected(StompHeaderAccessor headers, StompConnection stompConnection) { + this.stompConnection.setReady(); + publishBrokerAvailableEvent(); + } + private void handleTcpClientFailure(String message, Throwable ex) { if (logger.isErrorEnabled()) { logger.error(message + ", sessionId=" + this.sessionId, ex); } + disconnected(message); + } + + protected void disconnected(String errorMessage) { this.stompConnection.setDisconnected(); - sendError(message); + sendError(errorMessage); publishBrokerUnavailableEvent(); } @@ -445,12 +449,16 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private class SystemRelaySession extends RelaySession { - private static final long HEARTBEAT_SEND_INTERVAL = 10000; + private static final long HEARTBEAT_RECEIVE_MULTIPLIER = 3; + + private static final long HEARTBEAT_SEND_INTERVAL = 10000; - private static final long HEARTBEAT_RECEIVE_INTERVAL = 10000; + private static final long HEARTBEAT_RECEIVE_INTERVAL = 10000; public static final String ID = "stompRelaySystemSessionId"; + private final byte[] heartbeatPayload = new byte[] {'\n'}; + public SystemRelaySession() { super(ID); @@ -481,6 +489,39 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler publishBrokerUnavailableEvent(); } + @Override + protected void connected(StompHeaderAccessor headers, final StompConnection stompConnection) { + long brokerReceiveInterval = headers.getHeartbeat()[1]; + + if (HEARTBEAT_SEND_INTERVAL > 0 && brokerReceiveInterval > 0) { + long interval = Math.max(HEARTBEAT_SEND_INTERVAL, brokerReceiveInterval); + stompConnection.connection.on().writeIdle(interval, new Runnable() { + + @Override + public void run() { + stompConnection.connection.send(MessageBuilder.withPayload(heartbeatPayload).build()); + } + + }); + } + + long brokerSendInterval = headers.getHeartbeat()[0]; + if (HEARTBEAT_RECEIVE_INTERVAL > 0 && brokerSendInterval > 0) { + final long interval = + Math.max(HEARTBEAT_RECEIVE_INTERVAL, brokerSendInterval) * HEARTBEAT_RECEIVE_MULTIPLIER; + stompConnection.connection.on().readIdle(interval, new Runnable() { + @Override + public void run() { + String message = "Broker hearbeat missed: connection idle for more than " + interval + "ms"; + logger.warn(message); + disconnected(message); + } + }); + } + + super.connected(headers, stompConnection); + } + @Override protected void sendMessageToClient(Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java index 5c45244a204..e876df99ad2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java @@ -20,6 +20,8 @@ import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; import java.nio.charset.Charset; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.LinkedMultiValueMap; @@ -35,6 +37,10 @@ public class StompDecoder { private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + private static final byte[] HEARTBEAT_PAYLOAD = new byte[] {'\n'}; + + private final Log logger = LogFactory.getLog(StompDecoder.class); + /** * Decodes a STOMP frame in the given {@code buffer} into a {@link Message}. @@ -49,12 +55,20 @@ public class StompDecoder { MultiValueMap headers = readHeaders(buffer); byte[] payload = readPayload(buffer, headers); - return MessageBuilder.withPayloadAndHeaders(payload, + Message decodedMessage = MessageBuilder.withPayloadAndHeaders(payload, StompHeaderAccessor.create(StompCommand.valueOf(command), headers)).build(); + + if (logger.isTraceEnabled()) { + logger.trace("Decoded " + decodedMessage); + } + + return decodedMessage; } else { - // Heartbeat - return null; + if (logger.isTraceEnabled()) { + logger.trace("Decoded heartbeat"); + } + return MessageBuilder.withPayload(HEARTBEAT_PAYLOAD).build(); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java index f7760d57def..aa342c2e0f8 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java @@ -23,6 +23,8 @@ import java.nio.charset.Charset; import java.util.List; import java.util.Map.Entry; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; /** @@ -39,6 +41,7 @@ public final class StompEncoder { private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + private final Log logger = LogFactory.getLog(StompEncoder.class); /** * Encodes the given STOMP {@code message} into a {@code byte[]} @@ -49,16 +52,23 @@ public final class StompEncoder { */ public byte[] encode(Message message) { try { + if (logger.isTraceEnabled()) { + logger.trace("Encoding " + message); + } ByteArrayOutputStream baos = new ByteArrayOutputStream(); DataOutputStream output = new DataOutputStream(baos); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - writeCommand(headers, output); - writeHeaders(headers, message, output); - output.write(LF); - writeBody(message, output); - output.write((byte)0); + if (isHeartbeat(headers)) { + output.write(message.getPayload()); + } else { + writeCommand(headers, output); + writeHeaders(headers, message, output); + output.write(LF); + writeBody(message, output); + output.write((byte)0); + } return baos.toByteArray(); } @@ -67,6 +77,10 @@ public final class StompEncoder { } } + private boolean isHeartbeat(StompHeaderAccessor headers) { + return headers.getCommand() == null; + } + private void writeCommand(StompHeaderAccessor headers, DataOutputStream output) throws IOException { output.write(headers.getCommand().toString().getBytes(UTF8_CHARSET)); output.write(LF); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java index f39c43fe664..f3771059446 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java @@ -82,6 +82,8 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { public static final String STOMP_HEARTBEAT_HEADER = "heart-beat"; + private static final long[] DEFAULT_HEARTBEAT = new long[] {0, 0}; + // Other header names @@ -185,7 +187,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { result.put(STOMP_CONTENT_TYPE_HEADER, Arrays.asList(contentType.toString())); } - if (getCommand().requiresSubscriptionId()) { + if (getCommand() != null && getCommand().requiresSubscriptionId()) { String subscriptionId = getSubscriptionId(); if (subscriptionId != null) { String name = StompCommand.MESSAGE.equals(getCommand()) ? STOMP_SUBSCRIPTION_HEADER : STOMP_ID_HEADER; @@ -252,7 +254,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { public long[] getHeartbeat() { String rawValue = getFirstNativeHeader(STOMP_HEARTBEAT_HEADER); if (!StringUtils.hasText(rawValue)) { - return null; + return Arrays.copyOf(DEFAULT_HEARTBEAT, 2); } String[] rawValues = StringUtils.commaDelimitedListToStringArray(rawValue); return new long[] { Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])}; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java index 1c70bbb669a..6bbd655e91b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java @@ -166,11 +166,17 @@ public class StompProtocolHandler implements SubProtocolHandler { public void handleMessageToClient(WebSocketSession session, Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - headers.setCommandIfNotSet(StompCommand.MESSAGE); + if (headers.getCommand() == null && SimpMessageType.MESSAGE == headers.getMessageType()) { + headers.setCommandIfNotSet(StompCommand.MESSAGE); + } - if (this.handleConnect && StompCommand.CONNECTED.equals(headers.getCommand())) { - // Ignore since we already sent it - return; + if (headers.getCommand() == StompCommand.CONNECTED) { + if (this.handleConnect) { + // Ignore since we already sent it + return; + } else { + augmentConnectedHeaders(headers, session); + } } if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) { @@ -222,20 +228,26 @@ public class StompProtocolHandler implements SubProtocolHandler { } connectedHeaders.setHeartbeat(0,0); + augmentConnectedHeaders(connectedHeaders, session); + + // TODO: security + + Message connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build(); + String payload = new String(this.stompEncoder.encode(connectedMessage), Charset.forName("UTF-8")); + session.sendMessage(new TextMessage(payload)); + } + + private void augmentConnectedHeaders(StompHeaderAccessor headers, WebSocketSession session) { Principal principal = session.getPrincipal(); if (principal != null) { - connectedHeaders.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); - connectedHeaders.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId()); + headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); + headers.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId()); if (this.queueSuffixResolver != null) { String suffix = session.getId(); this.queueSuffixResolver.addQueueSuffix(principal.getName(), session.getId(), suffix); } } - - Message connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build(); - String payload = new String(this.stompEncoder.encode(connectedMessage), Charset.forName("UTF-8")); - session.sendMessage(new TextMessage(payload)); } protected void sendErrorMessage(WebSocketSession session, Throwable error) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index 7441536ef9b..3cf1544313c 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -287,20 +287,22 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { @Override public void handleMessage(Message message) throws MessagingException { - synchronized(this.monitor) { - for (MessageExchange exch : this.expected) { - if (exch.matchMessage(message)) { - if (exch.isDone()) { - this.expected.remove(exch); - this.actual.add(exch); - if (this.expected.isEmpty()) { - this.monitor.notifyAll(); + if (StompHeaderAccessor.wrap(message).getCommand() != null) { + synchronized(this.monitor) { + for (MessageExchange exch : this.expected) { + if (exch.matchMessage(message)) { + if (exch.isDone()) { + this.expected.remove(exch); + this.actual.add(exch); + if (this.expected.isEmpty()) { + this.monitor.notifyAll(); + } } + return; } - return; } + this.unexpected.add(message); } - this.unexpected.add(message); } }