From d26b9d60e5b8fdae01936870b1b68d4e2039aa52 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 11 Jun 2013 01:52:32 -0400 Subject: [PATCH] Refactor approach to working with STOMP headers --- .../messaging/MessageHeaders.java | 22 +- .../web/messaging/PubSubHeaders.java | 163 +++++++++ .../web/messaging/event/ReactorEventBus.java | 8 + .../service/AbstractMessageService.java | 13 +- .../service/PubSubMessageService.java | 34 +- .../method/AnnotationMessageService.java | 4 +- .../MessageChannelArgumentResolver.java | 25 +- .../method/MessageReturnValueHandler.java | 23 +- .../web/messaging/stomp/StompCommand.java | 22 ++ ...ion.java => StompConversionException.java} | 7 +- .../web/messaging/stomp/StompHeaders.java | 319 ++++++------------ .../web/messaging/stomp/StompMessage.java | 78 ----- .../web/messaging/stomp/StompSession.java | 41 --- .../socket/AbstractStompWebSocketHandler.java | 68 ++-- .../socket/DefaultStompWebSocketHandler.java | 308 +++++------------ .../stomp/socket/WebSocketStompSession.java | 92 ----- .../stomp/support/RelayStompService.java | 48 ++- .../stomp/support/StompHeaderMapper.java | 102 ------ .../stomp/support/StompMessageConverter.java | 162 +++++---- .../DestinationMessage.java} | 20 +- .../support/StompMessageConverterTests.java | 138 ++++++++ 21 files changed, 745 insertions(+), 952 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java rename spring-websocket/src/main/java/org/springframework/web/messaging/stomp/{StompException.java => StompConversionException.java} (82%) delete mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompMessage.java delete mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompSession.java delete mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/WebSocketStompSession.java delete mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompHeaderMapper.java rename spring-websocket/src/main/java/org/springframework/web/messaging/{stomp/socket/StompMessageInterceptor.java => support/DestinationMessage.java} (64%) create mode 100644 spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageHeaders.java b/spring-context/src/main/java/org/springframework/messaging/MessageHeaders.java index 21ae96dd23a..7fb6a26aba0 100644 --- a/spring-context/src/main/java/org/springframework/messaging/MessageHeaders.java +++ b/spring-context/src/main/java/org/springframework/messaging/MessageHeaders.java @@ -37,8 +37,9 @@ import org.apache.commons.logging.LogFactory; * The headers for a {@link Message}.
* IMPORTANT: MessageHeaders are immutable. Any mutating operation (e.g., put(..), putAll(..) etc.) * will result in {@link UnsupportedOperationException} + * *

- * TODO: update javadoc + * TODO: update below instructions * *

To create MessageHeaders instance use fluent MessageBuilder API *

@@ -76,16 +77,7 @@ public class MessageHeaders implements Map, Serializable {
 
 	public static final String TIMESTAMP = "timestamp";
 
-	public static final String REPLY_CHANNEL = "replyChannel";
-
-	public static final String ERROR_CHANNEL = "errorChannel";
-
-	public static final String CONTENT_TYPE = "content-type";
-
-	// DESTINATION ?
-
-	public static final List HEADER_NAMES =
-			Arrays.asList(ID, TIMESTAMP, REPLY_CHANNEL, ERROR_CHANNEL, CONTENT_TYPE);
+	public static final List HEADER_NAMES = Arrays.asList(ID, TIMESTAMP);
 
 
 	private final Map headers;
@@ -111,14 +103,6 @@ public class MessageHeaders implements Map, Serializable {
 		return this.get(TIMESTAMP, Long.class);
 	}
 
-	public Object getReplyChannel() {
-		return this.get(REPLY_CHANNEL);
-	}
-
-	public Object getErrorChannel() {
-		return this.get(ERROR_CHANNEL);
-	}
-
 	@SuppressWarnings("unchecked")
 	public  T get(Object key, Class type) {
 		Object value = this.headers.get(key);
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java
new file mode 100644
index 00000000000..0e020779c45
--- /dev/null
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java
@@ -0,0 +1,163 @@
+/*
+ * Copyright 2002-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.messaging;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.springframework.http.MediaType;
+import org.springframework.messaging.MessageHeaders;
+import org.springframework.util.CollectionUtils;
+
+
+/**
+ *
+ * @author Rossen Stoyanchev
+ * @since 4.0
+ */
+public class PubSubHeaders {
+
+	private static final String DESTINATIONS = "destinations";
+
+	private static final String CONTENT_TYPE = "contentType";
+
+	private static final String MESSAGE_TYPE = "messageType";
+
+	private static final String SUBSCRIPTION_ID = "subscriptionId";
+
+	private static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
+
+	private static final String SESSION_ID = "sessionId";
+
+	private static final String RAW_HEADERS = "rawHeaders";
+
+
+	private final Map messageHeaders;
+
+	private final Map rawHeaders;
+
+
+	/**
+	 * Constructor for building new headers.
+	 *
+	 * @param messageType the message type
+	 * @param protocolMessageType the protocol-specific message type or command
+	 */
+	public PubSubHeaders(MessageType messageType, Object protocolMessageType) {
+
+		this.messageHeaders = new HashMap();
+		this.messageHeaders.put(MESSAGE_TYPE, messageType);
+		if (protocolMessageType != null) {
+			this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
+		}
+
+		this.rawHeaders = new HashMap();
+		this.messageHeaders.put(RAW_HEADERS, this.rawHeaders);
+	}
+
+	public PubSubHeaders() {
+		this(MessageType.MESSAGE, null);
+	}
+
+	/**
+	 * Constructor for access to existing {@link MessageHeaders}.
+	 *
+	 * @param messageHeaders
+	 */
+	@SuppressWarnings("unchecked")
+	public PubSubHeaders(MessageHeaders messageHeaders, boolean readOnly) {
+
+		this.messageHeaders = readOnly ? messageHeaders : new HashMap(messageHeaders);
+		this.rawHeaders = this.messageHeaders.containsKey(RAW_HEADERS) ?
+				(Map) messageHeaders.get(RAW_HEADERS) : new HashMap();
+
+		if (this.messageHeaders.get(MESSAGE_TYPE) == null) {
+			this.messageHeaders.put(MESSAGE_TYPE, MessageType.MESSAGE);
+		}
+	}
+
+
+	public Map getMessageHeaders() {
+		return this.messageHeaders;
+	}
+
+	public Map getRawHeaders() {
+		return this.rawHeaders;
+	}
+
+	public MessageType getMessageType() {
+		return (MessageType) this.messageHeaders.get(MESSAGE_TYPE);
+	}
+
+	public void setProtocolMessageType(Object protocolMessageType) {
+		this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
+	}
+
+	public Object getProtocolMessageType() {
+		return this.messageHeaders.get(PROTOCOL_MESSAGE_TYPE);
+	}
+
+	public void setDestination(String destination) {
+		this.messageHeaders.put(DESTINATIONS, Arrays.asList(destination));
+	}
+
+	public String getDestination() {
+		@SuppressWarnings("unchecked")
+		List destination = (List) messageHeaders.get(DESTINATIONS);
+		return CollectionUtils.isEmpty(destination) ? null : destination.get(0);
+	}
+
+	@SuppressWarnings("unchecked")
+	public List getDestinations() {
+		return (List) messageHeaders.get(DESTINATIONS);
+	}
+
+	public void setDestinations(List destinations) {
+		if (destinations != null) {
+			this.messageHeaders.put(DESTINATIONS, destinations);
+		}
+	}
+
+	public MediaType getContentType() {
+		return (MediaType) this.messageHeaders.get(CONTENT_TYPE);
+	}
+
+	public void setContentType(MediaType mediaType) {
+		if (mediaType != null) {
+			this.messageHeaders.put(CONTENT_TYPE, mediaType);
+		}
+	}
+
+	public String getSubscriptionId() {
+		return (String) this.messageHeaders.get(SUBSCRIPTION_ID);
+	}
+
+	public void setSubscriptionId(String subscriptionId) {
+		this.messageHeaders.put(SUBSCRIPTION_ID, subscriptionId);
+	}
+
+	public String getSessionId() {
+		return (String) this.messageHeaders.get(SESSION_ID);
+	}
+
+	public void setSessionId(String sessionId) {
+		this.messageHeaders.put(SESSION_ID, sessionId);
+	}
+
+}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/event/ReactorEventBus.java b/spring-websocket/src/main/java/org/springframework/web/messaging/event/ReactorEventBus.java
index 8c782250081..a6e9ce600fc 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/event/ReactorEventBus.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/event/ReactorEventBus.java
@@ -16,6 +16,9 @@
 
 package org.springframework.web.messaging.event;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
 import reactor.core.Reactor;
 import reactor.fn.Consumer;
 import reactor.fn.Event;
@@ -28,6 +31,8 @@ import reactor.fn.selector.ObjectSelector;
  */
 public class ReactorEventBus implements EventBus {
 
+	private static Log logger = LogFactory.getLog(ReactorEventBus.class);
+
 	private final Reactor reactor;
 
 
@@ -37,6 +42,9 @@ public class ReactorEventBus implements EventBus {
 
 	@Override
 	public void send(String key, Object data) {
+		if (logger.isTraceEnabled()) {
+			logger.trace("Sending notification key=" + key + ", data=" + data);
+		}
 		this.reactor.notify(key, Event.wrap(data));
 	}
 
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
index 68a26d00d1a..d8a81b662d3 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
@@ -27,6 +27,7 @@ import org.springframework.util.AntPathMatcher;
 import org.springframework.util.Assert;
 import org.springframework.util.PathMatcher;
 import org.springframework.web.messaging.MessageType;
+import org.springframework.web.messaging.PubSubHeaders;
 import org.springframework.web.messaging.event.EventBus;
 import org.springframework.web.messaging.event.EventConsumer;
 
@@ -37,11 +38,14 @@ import org.springframework.web.messaging.event.EventConsumer;
  */
 public abstract class AbstractMessageService {
 
-	public static final String MESSAGE_KEY = "messageKey";
+	protected final Log logger = LogFactory.getLog(getClass());
+
+
+	public static final String CLIENT_TO_SERVER_MESSAGE_KEY = "clientToServerMessageKey";
 
 	public static final String CLIENT_CONNECTION_CLOSED_KEY = "clientConnectionClosed";
 
-	protected final Log logger = LogFactory.getLog(getClass());
+	public static final String SERVER_TO_CLIENT_MESSAGE_KEY = "serverToClientMessageKey";
 
 
 	private final EventBus eventBus;
@@ -58,7 +62,7 @@ public abstract class AbstractMessageService {
 		Assert.notNull(reactor, "reactor is required");
 		this.eventBus = reactor;
 
-		this.eventBus.registerConsumer(MESSAGE_KEY, new EventConsumer>() {
+		this.eventBus.registerConsumer(CLIENT_TO_SERVER_MESSAGE_KEY, new EventConsumer>() {
 
 			@Override
 			public void accept(Message message) {
@@ -124,7 +128,8 @@ public abstract class AbstractMessageService {
 	}
 
 	private boolean isAllowedDestination(Message message) {
-		String destination = (String) message.getHeaders().get("destination");
+		PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
+		String destination = headers.getDestination();
 		if (destination == null) {
 			return true;
 		}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
index a1a96c0a9ca..963a978368b 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
@@ -17,14 +17,13 @@
 package org.springframework.web.messaging.service;
 
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
-import org.springframework.http.MediaType;
 import org.springframework.messaging.GenericMessage;
 import org.springframework.messaging.Message;
+import org.springframework.web.messaging.PubSubHeaders;
 import org.springframework.web.messaging.converter.CompositeMessageConverter;
 import org.springframework.web.messaging.converter.MessageConverter;
 import org.springframework.web.messaging.event.EventBus;
@@ -61,26 +60,21 @@ public class PubSubMessageService extends AbstractMessageService {
 			logger.debug("Message received: " + message);
 		}
 
-		Map headers = new HashMap();
-		headers.put("destination", message.getHeaders().get("destination"));
-
-		MediaType contentType = (MediaType) message.getHeaders().get("content-type");
-		headers.put("content-type", contentType);
-
 		try {
 			// Convert to byte[] payload before the fan-out
-			byte[] payload = payloadConverter.convertToPayload(message.getPayload(), contentType);
-			message = new GenericMessage(payload, headers);
+			PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
+			byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType());
+			message = new GenericMessage(payload, message.getHeaders());
 
-			getEventBus().send(getPublishKey(message), message);
+			getEventBus().send(getPublishKey(inHeaders.getDestination()), message);
 		}
 		catch (Exception ex) {
 			logger.error("Failed to publish " + message, ex);
 		}
 	}
 
-	private String getPublishKey(Message message) {
-		return "destination:" + (String) message.getHeaders().get("destination");
+	private String getPublishKey(String destination) {
+		return "destination:" + destination;
 	}
 
 	@Override
@@ -88,12 +82,20 @@ public class PubSubMessageService extends AbstractMessageService {
 		if (logger.isDebugEnabled()) {
 			logger.debug("Subscribe " + message);
 		}
-		final String replyKey = (String) message.getHeaders().getReplyChannel();
-		EventRegistration registration = getEventBus().registerConsumer(getPublishKey(message),
+		PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
+		final String subscriptionId = headers.getSubscriptionId();
+		EventRegistration registration = getEventBus().registerConsumer(getPublishKey(headers.getDestination()),
 				new EventConsumer>() {
 					@Override
 					public void accept(Message message) {
-						getEventBus().send(replyKey, message);
+						PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
+						PubSubHeaders outHeaders = new PubSubHeaders();
+						outHeaders.setDestinations(inHeaders.getDestinations());
+						outHeaders.setContentType(inHeaders.getContentType());
+						outHeaders.setSubscriptionId(subscriptionId);
+						Object payload = message.getPayload();
+						message = new GenericMessage(payload, outHeaders.getMessageHeaders());
+						getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message);
 					}
 				});
 
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java
index 570ec633c4a..55ab59ea86b 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java
@@ -35,6 +35,7 @@ import org.springframework.messaging.annotation.MessageMapping;
 import org.springframework.stereotype.Controller;
 import org.springframework.util.ClassUtils;
 import org.springframework.util.ReflectionUtils.MethodFilter;
+import org.springframework.web.messaging.PubSubHeaders;
 import org.springframework.web.messaging.annotation.SubscribeEvent;
 import org.springframework.web.messaging.annotation.UnsubscribeEvent;
 import org.springframework.web.messaging.converter.MessageConverter;
@@ -166,7 +167,8 @@ public class AnnotationMessageService extends AbstractMessageService implements
 
 	private void handleMessage(final Message message, Map handlerMethods) {
 
-		String destination = (String) message.getHeaders().get("destination");
+		PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
+		String destination = headers.getDestination();
 
 		HandlerMethod match = getHandlerMethod(destination, handlerMethods);
 		if (match == null) {
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java
index 728ce0261ca..ad16b5f97c0 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java
@@ -16,16 +16,11 @@
 
 package org.springframework.web.messaging.service.method;
 
-import java.util.HashMap;
-import java.util.Map;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.springframework.core.MethodParameter;
 import org.springframework.messaging.GenericMessage;
 import org.springframework.messaging.Message;
 import org.springframework.messaging.MessageChannel;
-import org.springframework.web.messaging.MessageType;
+import org.springframework.web.messaging.PubSubHeaders;
 import org.springframework.web.messaging.event.EventBus;
 import org.springframework.web.messaging.service.AbstractMessageService;
 
@@ -38,8 +33,6 @@ import reactor.util.Assert;
  */
 public class MessageChannelArgumentResolver implements ArgumentResolver {
 
-	private static Log logger = LogFactory.getLog(MessageChannelArgumentResolver.class);
-
 	private final EventBus eventBus;
 
 
@@ -56,24 +49,18 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
 	@Override
 	public Object resolveArgument(MethodParameter parameter, Message message) throws Exception {
 
-		final String sessionId = (String) message.getHeaders().get("sessionId");
+		final String sessionId = new PubSubHeaders(message.getHeaders(), true).getSessionId();
 
 		return new MessageChannel() {
 
 			@Override
 			public boolean send(Message message) {
 
-				Map headers = new HashMap(message.getHeaders());
-				headers.put("messageType", MessageType.MESSAGE);
-				headers.put("sessionId", sessionId);
-				message = new GenericMessage(message.getPayload(), headers);
-
-				if (logger.isTraceEnabled()) {
-					logger.trace("Sending notification: " + message);
-				}
+				PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), false);
+				headers.setSessionId(sessionId);
+				message = new GenericMessage(message.getPayload(), headers.getMessageHeaders());
 
-				String key = AbstractMessageService.MESSAGE_KEY;
-				MessageChannelArgumentResolver.this.eventBus.send(key, message);
+				eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message);
 
 				return true;
 			}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java
index b5f0e1efd10..4a357b94d61 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java
@@ -16,11 +16,12 @@
 
 package org.springframework.web.messaging.service.method;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.springframework.core.MethodParameter;
+import org.springframework.messaging.GenericMessage;
 import org.springframework.messaging.Message;
+import org.springframework.web.messaging.PubSubHeaders;
 import org.springframework.web.messaging.event.EventBus;
+import org.springframework.web.messaging.service.AbstractMessageService;
 
 import reactor.util.Assert;
 
@@ -31,8 +32,6 @@ import reactor.util.Assert;
  */
 public class MessageReturnValueHandler implements ReturnValueHandler {
 
-	private static Log logger = LogFactory.getLog(MessageReturnValueHandler.class);
-
 	private final EventBus eventBus;
 
 
@@ -67,13 +66,17 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
 			return;
 		}
 
-		String replyTo = (String) message.getHeaders().getReplyChannel();
-		Assert.notNull(replyTo, "Cannot reply to: " + message);
+		PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
+		String sessionId = inHeaders.getSessionId();
+		String subscriptionId = inHeaders.getSubscriptionId();
+		Assert.notNull(subscriptionId, "No subscription id: " + message);
 
-		if (logger.isTraceEnabled()) {
-			logger.trace("Sending notification: " + message);
-		}
-		this.eventBus.send(replyTo, returnMessage);
+		PubSubHeaders outHeaders = new PubSubHeaders(returnMessage.getHeaders(), false);
+		outHeaders.setSessionId(sessionId);
+		outHeaders.setSubscriptionId(subscriptionId);
+		returnMessage = new GenericMessage(returnMessage.getPayload(), outHeaders.getMessageHeaders());
+
+		this.eventBus.send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, returnMessage);
  	}
 
 }
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompCommand.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompCommand.java
index cb99fd9e34d..95ca8d17751 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompCommand.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompCommand.java
@@ -16,6 +16,11 @@
 
 package org.springframework.web.messaging.stomp;
 
+import java.util.HashMap;
+import java.util.Map;
+
+import org.springframework.web.messaging.MessageType;
+
 
 /**
  *
@@ -43,4 +48,21 @@ public enum StompCommand {
 	RECEIPT,
 	ERROR;
 
+
+	private static Map commandToMessageType = new HashMap();
+
+	static {
+		commandToMessageType.put(StompCommand.CONNECT, MessageType.CONNECT);
+		commandToMessageType.put(StompCommand.STOMP, MessageType.CONNECT);
+		commandToMessageType.put(StompCommand.SEND, MessageType.MESSAGE);
+		commandToMessageType.put(StompCommand.SUBSCRIBE, MessageType.SUBSCRIBE);
+		commandToMessageType.put(StompCommand.UNSUBSCRIBE, MessageType.UNSUBSCRIBE);
+		commandToMessageType.put(StompCommand.DISCONNECT, MessageType.DISCONNECT);
+	}
+
+	public MessageType getMessageType() {
+		MessageType messageType = commandToMessageType.get(this);
+		return (messageType != null) ? messageType : MessageType.OTHER;
+	}
+
 }
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompException.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompConversionException.java
similarity index 82%
rename from spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompException.java
rename to spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompConversionException.java
index 7f2915aa9e7..bbc951b9a17 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompException.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompConversionException.java
@@ -20,17 +20,16 @@ import org.springframework.core.NestedRuntimeException;
 /**
  * @author Gary Russell
  * @since 4.0
- *
  */
 @SuppressWarnings("serial")
-public class StompException extends NestedRuntimeException {
+public class StompConversionException extends NestedRuntimeException {
 
 
-	public StompException(String msg, Throwable cause) {
+	public StompConversionException(String msg, Throwable cause) {
 		super(msg, cause);
 	}
 
-	public StompException(String msg) {
+	public StompConversionException(String msg) {
 		super(msg);
 	}
 
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java
index d8e5535de4c..eb6e0fafdf1 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java
@@ -16,42 +16,32 @@
 
 package org.springframework.web.messaging.stomp;
 
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
-import java.util.LinkedHashMap;
-import java.util.LinkedList;
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
 
 import org.springframework.http.MediaType;
-import org.springframework.util.Assert;
-import org.springframework.util.MultiValueMap;
+import org.springframework.messaging.MessageHeaders;
 import org.springframework.util.StringUtils;
+import org.springframework.web.messaging.PubSubHeaders;
+
+import reactor.util.Assert;
 
 
 /**
+ * STOMP adapter for {@link MessageHeaders}.
  *
  * @author Rossen Stoyanchev
  * @since 4.0
  */
-public class StompHeaders implements MultiValueMap, Serializable {
-
-	private static final long serialVersionUID = 1L;
+public class StompHeaders extends PubSubHeaders {
 
-	// TODO: separate client from server headers so they can't be mixed
-
-	// Client
 	private static final String ID = "id";
 
 	private static final String HOST = "host";
 
 	private static final String ACCEPT_VERSION = "accept-version";
 
-	// Server
-
 	private static final String MESSAGE_ID = "message-id";
 
 	private static final String RECEIPT_ID = "receipt-id";
@@ -62,8 +52,6 @@ public class StompHeaders implements MultiValueMap, Serializable
 
 	private static final String MESSAGE = "message";
 
-	// Client and Server
-
 	private static final String ACK = "ack";
 
 	private static final String DESTINATION = "destination";
@@ -75,270 +63,167 @@ public class StompHeaders implements MultiValueMap, Serializable
 	private static final String HEARTBEAT = "heart-beat";
 
 
-	public static final List STANDARD_HEADER_NAMES =
-			Arrays.asList(ID, HOST, ACCEPT_VERSION, MESSAGE_ID, RECEIPT_ID, SUBSCRIPTION,
-					VERSION, MESSAGE, ACK, DESTINATION, CONTENT_LENGTH, CONTENT_TYPE, HEARTBEAT);
-
-
-	private final Map> headers;
-
-
 	/**
-	 * Private constructor that can create read-only {@code StompHeaders} instances.
+	 * Constructor for building new headers.
+	 *
+	 * @param command the STOMP command
 	 */
-	private StompHeaders(Map> headers, boolean readOnly) {
-		Assert.notNull(headers, "'headers' must not be null");
-		if (readOnly) {
-			Map> map = new LinkedHashMap>(headers.size());
-			for (Entry> entry : headers.entrySet()) {
-				List values = Collections.unmodifiableList(entry.getValue());
-				map.put(entry.getKey(), values);
-			}
-			this.headers = Collections.unmodifiableMap(map);
-		}
-		else {
-			this.headers = headers;
-		}
+	public StompHeaders(StompCommand command) {
+		super(command.getMessageType(), command);
 	}
 
 	/**
-	 * Constructs a new, empty instance of the {@code StompHeaders} object.
+	 * Constructor for access to existing {@link MessageHeaders}.
+	 *
+	 * @param messageHeaders the existing message headers
+	 * @param readOnly whether the resulting instance will be used for read-only access,
+	 *        if {@code true}, then set methods will throw exceptions; if {@code false}
+	 *        they will work.
 	 */
-	public StompHeaders() {
-		this(new LinkedHashMap>(4), false);
+	public StompHeaders(MessageHeaders messageHeaders, boolean readOnly) {
+		super(messageHeaders, readOnly);
 	}
 
-	/**
-	 * Returns {@code StompHeaders} object that can only be read, not written to.
-	 */
-	public static StompHeaders readOnlyStompHeaders(StompHeaders headers) {
-		return new StompHeaders(headers, true);
+	@Override
+	public StompCommand getProtocolMessageType() {
+		return (StompCommand) super.getProtocolMessageType();
+	}
+
+	public StompCommand getStompCommand() {
+		return (StompCommand) super.getProtocolMessageType();
 	}
 
 	public Set getAcceptVersion() {
-		String rawValue = getFirst(ACCEPT_VERSION);
+		String rawValue = getRawHeaders().get(ACCEPT_VERSION);
 		return (rawValue != null) ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.emptySet();
 	}
 
 	public void setAcceptVersion(String acceptVersion) {
-		set(ACCEPT_VERSION, acceptVersion);
+		getRawHeaders().put(ACCEPT_VERSION, acceptVersion);
 	}
 
-	public String getVersion() {
-		return getFirst(VERSION);
-	}
-
-	public void setVersion(String version) {
-		set(VERSION, version);
-	}
-
-	public String getDestination() {
-		return getFirst(DESTINATION);
+	@Override
+	public void setDestination(String destination) {
+		if (destination != null) {
+			super.setDestination(destination);
+			getRawHeaders().put(DESTINATION, destination);
+		}
 	}
 
-	public void setDestination(String destination) {
-		set(DESTINATION, destination);
+	@Override
+	public void setDestinations(List destinations) {
+		if (destinations != null) {
+			super.setDestinations(destinations);
+			getRawHeaders().put(DESTINATION, destinations.get(0));
+		}
 	}
 
-	public MediaType getContentType() {
-		String contentType = getFirst(CONTENT_TYPE);
-		return StringUtils.hasText(contentType) ? MediaType.valueOf(contentType) : null;
+	public long[] getHeartbeat() {
+		String rawValue = getRawHeaders().get(HEARTBEAT);
+		if (!StringUtils.hasText(rawValue)) {
+			return null;
+		}
+		String[] rawValues = StringUtils.commaDelimitedListToStringArray(rawValue);
+		// TODO assertions
+		return new long[] { Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])};
 	}
 
 	public void setContentType(MediaType mediaType) {
 		if (mediaType != null) {
-			set(CONTENT_TYPE, mediaType.toString());
-		}
-		else {
-			remove(CONTENT_TYPE);
+			super.setContentType(mediaType);
+			getRawHeaders().put(CONTENT_TYPE, mediaType.toString());
 		}
 	}
 
 	public Integer getContentLength() {
-		String contentLength = getFirst(CONTENT_LENGTH);
+		String contentLength = getRawHeaders().get(CONTENT_LENGTH);
 		return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null;
 	}
 
 	public void setContentLength(int contentLength) {
-		set(CONTENT_LENGTH, String.valueOf(contentLength));
+		getRawHeaders().put(CONTENT_LENGTH, String.valueOf(contentLength));
 	}
 
-	public long[] getHeartbeat() {
-		String rawValue = getFirst(HEARTBEAT);
-		if (!StringUtils.hasText(rawValue)) {
-			return null;
-		}
-		String[] rawValues = StringUtils.commaDelimitedListToStringArray(rawValue);
-		// TODO assertions
-		return new long[] { Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])};
+	@Override
+	public String getSubscriptionId() {
+		return StompCommand.SUBSCRIBE.equals(getStompCommand()) ? getRawHeaders().get(ID) : null;
+	}
+
+	@Override
+	public void setSubscriptionId(String subscriptionId) {
+		Assert.isTrue(StompCommand.MESSAGE.equals(getStompCommand()),
+				"\"subscription\" can only be set on a STOMP MESSAGE frame");
+		super.setSubscriptionId(subscriptionId);
+		getRawHeaders().put(SUBSCRIPTION, subscriptionId);
 	}
 
 	public void setHeartbeat(long cx, long cy) {
-		set(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
+		getRawHeaders().put(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
 	}
 
-	public String getId() {
-		return getFirst(ID);
+	public String getMessage() {
+		return getRawHeaders().get(MESSAGE);
 	}
 
-	public void setId(String id) {
-		set(ID, id);
+	public void setMessage(String content) {
+		getRawHeaders().put(MESSAGE, content);
 	}
 
 	public String getMessageId() {
-		return getFirst(MESSAGE_ID);
+		return getRawHeaders().get(MESSAGE_ID);
 	}
 
 	public void setMessageId(String id) {
-		set(MESSAGE_ID, id);
-	}
-
-	public String getSubscription() {
-		return getFirst(SUBSCRIPTION);
-	}
-
-	public void setSubscription(String id) {
-		set(SUBSCRIPTION, id);
+		getRawHeaders().put(MESSAGE_ID, id);
 	}
 
-	public String getMessage() {
-		return getFirst(MESSAGE);
+	public String getVersion() {
+		return getRawHeaders().get(VERSION);
 	}
 
-	public void setMessage(String id) {
-		set(MESSAGE, id);
+	public void setVersion(String version) {
+		getRawHeaders().put(VERSION, version);
 	}
 
 
-	// MultiValueMap methods
-
 	/**
-	 * Return the first header value for the given header name, if any.
-	 * @param headerName the header name
-	 * @return the first header value; or {@code null}
+	 * Update generic message headers from raw headers. This method only needs to be
+	 * invoked when raw headers are added via {@link #getRawHeaders()}.
 	 */
-	public String getFirst(String headerName) {
-		List headerValues = headers.get(headerName);
-		return headerValues != null ? headerValues.get(0) : null;
-	}
-
-	/**
-	 * Add the given, single header value under the given name.
-	 * @param headerName  the header name
-	 * @param headerValue the header value
-	 * @throws UnsupportedOperationException if adding headers is not supported
-	 * @see #put(String, List)
-	 * @see #set(String, String)
-	 */
-	public void add(String headerName, String headerValue) {
-		List headerValues = headers.get(headerName);
-		if (headerValues == null) {
-			headerValues = new LinkedList();
-			this.headers.put(headerName, headerValues);
+	public void updateMessageHeaders() {
+		String destination = getRawHeaders().get(DESTINATION);
+		if (destination != null) {
+			setDestination(destination);
+		}
+		String contentType = getRawHeaders().get(CONTENT_TYPE);
+		if (contentType != null) {
+			setContentType(MediaType.parseMediaType(contentType));
+		}
+		if (StompCommand.SUBSCRIBE.equals(getStompCommand())) {
+			if (getRawHeaders().get(ID) != null) {
+				super.setSubscriptionId(getRawHeaders().get(ID));
+			}
 		}
-		headerValues.add(headerValue);
 	}
 
 	/**
-	 * Set the given, single header value under the given name.
-	 * @param headerName  the header name
-	 * @param headerValue the header value
-	 * @throws UnsupportedOperationException if adding headers is not supported
-	 * @see #put(String, List)
-	 * @see #add(String, String)
+	 * Update raw headers from generic message headers. This method only needs to be
+	 * invoked if creating {@link StompHeaders} from {@link MessageHeaders} that never
+	 * contained raw headers.
 	 */
-	public void set(String headerName, String headerValue) {
-		List headerValues = new LinkedList();
-		headerValues.add(headerValue);
-		headers.put(headerName, headerValues);
-	}
-
-	public void setAll(Map values) {
-		for (Entry entry : values.entrySet()) {
-			set(entry.getKey(), entry.getValue());
+	public void updateRawHeaders() {
+		String destination = getDestination();
+		if (destination != null) {
+			getRawHeaders().put(DESTINATION, destination);
 		}
-	}
-
-	public Map toSingleValueMap() {
-		LinkedHashMap singleValueMap = new LinkedHashMap(this.headers.size());
-		for (Entry> entry : headers.entrySet()) {
-			singleValueMap.put(entry.getKey(), entry.getValue().get(0));
-		}
-		return singleValueMap;
-	}
-
-
-	// Map implementation
-
-	public int size() {
-		return this.headers.size();
-	}
-
-	public boolean isEmpty() {
-		return this.headers.isEmpty();
-	}
-
-	public boolean containsKey(Object key) {
-		return this.headers.containsKey(key);
-	}
-
-	public boolean containsValue(Object value) {
-		return this.headers.containsValue(value);
-	}
-
-	public List get(Object key) {
-		return this.headers.get(key);
-	}
-
-	public List put(String key, List value) {
-		return this.headers.put(key, value);
-	}
-
-	public List remove(Object key) {
-		return this.headers.remove(key);
-	}
-
-	public void putAll(Map> m) {
-		this.headers.putAll(m);
-	}
-
-	public void clear() {
-		this.headers.clear();
-	}
-
-	public Set keySet() {
-		return this.headers.keySet();
-	}
-
-	public Collection> values() {
-		return this.headers.values();
-	}
-
-	public Set>> entrySet() {
-		return this.headers.entrySet();
-	}
-
-
-	@Override
-	public boolean equals(Object other) {
-		if (this == other) {
-			return true;
+		MediaType contentType = getContentType();
+		if (contentType != null) {
+			getRawHeaders().put(CONTENT_TYPE, contentType.toString());
 		}
-		if (!(other instanceof StompHeaders)) {
-			return false;
+		String subscriptionId = getSubscriptionId();
+		if (subscriptionId != null) {
+			getRawHeaders().put(SUBSCRIPTION, subscriptionId);
 		}
-		StompHeaders otherHeaders = (StompHeaders) other;
-		return this.headers.equals(otherHeaders.headers);
-	}
-
-	@Override
-	public int hashCode() {
-		return this.headers.hashCode();
-	}
-
-	@Override
-	public String toString() {
-		return this.headers.toString();
 	}
 
 }
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompMessage.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompMessage.java
deleted file mode 100644
index 78bbe43efd4..00000000000
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompMessage.java
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright 2002-2013 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.web.messaging.stomp;
-
-import java.nio.charset.Charset;
-
-
-/**
- *
- * @author Rossen Stoyanchev
- * @since 4.0
- */
-public class StompMessage {
-
-	public static final Charset CHARSET = Charset.forName("UTF-8");
-
-	private final StompCommand command;
-
-	private final StompHeaders headers;
-
-	private final byte[] payload;
-
-	private String sessionId;
-
-
-	public StompMessage(StompCommand command, StompHeaders headers, byte[] payload) {
-		this.command = command;
-		this.headers = (headers != null) ? headers : new StompHeaders();
-		this.payload = payload;
-	}
-
-	/**
-	 * Constructor for empty payload message.
-	 */
-	public StompMessage(StompCommand command, StompHeaders headers) {
-		this(command, headers, new byte[0]);
-	}
-
-	public StompCommand getCommand() {
-		return this.command;
-	}
-
-	public StompHeaders getHeaders() {
-		return this.headers;
-	}
-
-	public byte[] getPayload() {
-		return this.payload;
-	}
-
-	public void setSessionId(String sessionId) {
-		this.sessionId = sessionId;
-	}
-
-	public String getSessionId() {
-		return this.sessionId;
-	}
-
-	@Override
-	public String toString() {
-		return "StompMessage [" + command + ", headers=" + this.headers + ", payload=" + new String(this.payload) + "]";
-	}
-
-}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompSession.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompSession.java
deleted file mode 100644
index 1da941be2d8..00000000000
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompSession.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright 2002-2013 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.web.messaging.stomp;
-
-import java.io.IOException;
-
-
-/**
- * @author Rossen Stoyanchev
- * @since 4.0
- */
-public interface StompSession {
-
-	String getId();
-
-	/**
-	 * TODO...
-	 * 

- * If the message is a STOMP ERROR message, the session will also be closed. - */ - void sendMessage(StompMessage message) throws IOException; - - /** - * Register a task to be invoked if the underlying connection is closed. - */ - void registerConnectionClosedTask(Runnable task); - -} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java index 43641d28937..d8e66221446 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java @@ -15,14 +15,14 @@ */ package org.springframework.web.messaging.stomp.socket; +import java.nio.charset.Charset; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import org.springframework.util.Assert; -import org.springframework.web.messaging.stomp.StompCommand; +import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.Message; import org.springframework.web.messaging.stomp.StompHeaders; -import org.springframework.web.messaging.stomp.StompMessage; -import org.springframework.web.messaging.stomp.StompSession; +import org.springframework.web.messaging.stomp.StompCommand; import org.springframework.web.messaging.stomp.support.StompMessageConverter; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; @@ -36,57 +36,65 @@ import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; */ public abstract class AbstractStompWebSocketHandler extends TextWebSocketHandlerAdapter { - private final StompMessageConverter messageConverter = new StompMessageConverter(); + private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); - private final Map sessions = new ConcurrentHashMap(); + private final Map sessions = new ConcurrentHashMap(); + public StompMessageConverter getStompMessageConverter() { + return this.stompMessageConverter; + } + + protected WebSocketSession getWebSocketSession(String sessionId) { + return this.sessions.get(sessionId); + } + @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { - WebSocketStompSession stompSession = new WebSocketStompSession(session, this.messageConverter); - this.sessions.put(session.getId(), stompSession); + this.sessions.put(session.getId(), session); } @Override - protected void handleTextMessage(WebSocketSession session, TextMessage message) { - - StompSession stompSession = this.sessions.get(session.getId()); - Assert.notNull(stompSession, "No STOMP session for WebSocket session id=" + session.getId()); - + protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) { try { - StompMessage stompMessage = this.messageConverter.toStompMessage(message.getPayload()); - stompMessage.setSessionId(stompSession.getId()); + String payload = textMessage.getPayload(); + Message message = this.stompMessageConverter.toMessage(payload, session.getId()); // TODO: validate size limits // http://stomp.github.io/stomp-specification-1.2.html#Size_Limits - handleStompMessage(stompSession, stompMessage); + handleStompMessage(session, message); // TODO: send RECEIPT message if incoming message has "receipt" header // http://stomp.github.io/stomp-specification-1.2.html#Header_receipt } catch (Throwable error) { - StompHeaders headers = new StompHeaders(); - headers.setMessage(error.getMessage()); - StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers); - try { - stompSession.sendMessage(errorMessage); - } - catch (Throwable t) { - // ignore - } + sendErrorMessage(session, error); } } - protected abstract void handleStompMessage(StompSession stompSession, StompMessage stompMessage); + protected void sendErrorMessage(WebSocketSession session, Throwable error) { + + StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR); + stompHeaders.setMessage(error.getMessage()); + + Message errorMessage = new GenericMessage(new byte[0], stompHeaders.getMessageHeaders()); + byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage); + + try { + session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); + } + catch (Throwable t) { + // ignore + } + } + + protected abstract void handleStompMessage(WebSocketSession session, Message message); @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { - WebSocketStompSession stompSession = this.sessions.remove(session.getId()); - if (stompSession != null) { - stompSession.handleConnectionClosed(); - } + this.sessions.remove(session.getId()); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java index 506841fc9ce..a21cf4d0744 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java @@ -16,32 +16,28 @@ package org.springframework.web.messaging.stomp.socket; import java.io.IOException; -import java.util.ArrayList; +import java.nio.charset.Charset; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.MediaType; import org.springframework.messaging.GenericMessage; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHeaders; -import org.springframework.util.CollectionUtils; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.converter.CompositeMessageConverter; import org.springframework.web.messaging.converter.MessageConverter; import org.springframework.web.messaging.event.EventBus; import org.springframework.web.messaging.event.EventConsumer; -import org.springframework.web.messaging.event.EventRegistration; import org.springframework.web.messaging.service.AbstractMessageService; import org.springframework.web.messaging.stomp.StompCommand; -import org.springframework.web.messaging.stomp.StompException; +import org.springframework.web.messaging.stomp.StompConversionException; import org.springframework.web.messaging.stomp.StompHeaders; -import org.springframework.web.messaging.stomp.StompMessage; -import org.springframework.web.messaging.stomp.StompSession; -import org.springframework.web.messaging.stomp.support.StompHeaderMapper; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; /** * @author Gary Russell @@ -57,14 +53,59 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler private MessageConverter payloadConverter = new CompositeMessageConverter(null); - private final StompHeaderMapper headerMapper = new StompHeaderMapper(); - - private Map> registrationsBySession = - new ConcurrentHashMap>(); - public DefaultStompWebSocketHandler(EventBus eventBus) { + this.eventBus = eventBus; + + this.eventBus.registerConsumer(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, + new EventConsumer>() { + @Override + public void accept(Message message) { + + StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false); + if (stompHeaders.getProtocolMessageType() == null) { + stompHeaders.setProtocolMessageType(StompCommand.MESSAGE); + } + + if (StompCommand.CONNECTED.equals(stompHeaders.getStompCommand())) { + // Ignore for now since we already sent it + return; + } + + String sessionId = stompHeaders.getSessionId(); + WebSocketSession session = getWebSocketSession(sessionId); + + byte[] payload; + try { + MediaType contentType = stompHeaders.getContentType(); + payload = payloadConverter.convertToPayload(message.getPayload(), contentType); + } + catch (Exception e) { + logger.error("Failed to send " + message, e); + return; + } + + try { + Map messageHeaders = stompHeaders.getMessageHeaders(); + Message byteMessage = new GenericMessage(payload, messageHeaders); + byte[] bytes = getStompMessageConverter().fromMessage(byteMessage); + session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); + } + catch (Throwable t) { + sendErrorMessage(session, t); + } + finally { + if (StompCommand.ERROR.equals(stompHeaders.getStompCommand())) { + try { + session.close(CloseStatus.PROTOCOL_ERROR); + } + catch (IOException e) { + } + } + } + } + }); } @@ -72,252 +113,83 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler this.payloadConverter = new CompositeMessageConverter(converters); } - public void handleStompMessage(final StompSession session, StompMessage stompMessage) { + public void handleStompMessage(final WebSocketSession session, Message message) { if (logger.isTraceEnabled()) { - logger.trace("Processing: " + stompMessage); + logger.trace("Processing: " + message); } try { - MessageType messageType = MessageType.OTHER; - String replyKey = null; - - StompCommand command = stompMessage.getCommand(); - if (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command)) { - session.registerConnectionClosedTask(new ConnectionClosedTask(session)); - messageType = MessageType.CONNECT; - replyKey = handleConnect(session, stompMessage); - } - else if (StompCommand.SEND.equals(command)) { - messageType = MessageType.MESSAGE; - handleSend(session, stompMessage); - } - else if (StompCommand.SUBSCRIBE.equals(command)) { - messageType = MessageType.SUBSCRIBE; - replyKey = handleSubscribe(session, stompMessage); + StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), true); + MessageType messageType = stompHeaders.getMessageType(); + if (MessageType.CONNECT.equals(messageType)) { + handleConnect(session, message); } - else if (StompCommand.UNSUBSCRIBE.equals(command)) { - messageType = MessageType.UNSUBSCRIBE; - handleUnsubscribe(session, stompMessage); + else if (MessageType.MESSAGE.equals(messageType)) { + handleMessage(message); } - else if (StompCommand.DISCONNECT.equals(command)) { - messageType = MessageType.DISCONNECT; - handleDisconnect(session, stompMessage); + else if (MessageType.SUBSCRIBE.equals(messageType)) { + handleSubscribe(message); } - else { - sendErrorMessage(session, "Invalid STOMP command " + command); - return; + else if (MessageType.UNSUBSCRIBE.equals(messageType)) { + handleUnsubscribe(message); } - - Map messageHeaders = this.headerMapper.toMessageHeaders(stompMessage.getHeaders()); - messageHeaders.put("messageType", messageType); - if (replyKey != null) { - messageHeaders.put(MessageHeaders.REPLY_CHANNEL, replyKey); + else if (MessageType.DISCONNECT.equals(messageType)) { + handleDisconnect(message); } - messageHeaders.put("stompCommand", command); - messageHeaders.put("sessionId", session.getId()); - - Message genericMessage = new GenericMessage(stompMessage.getPayload(), messageHeaders); - - if (logger.isTraceEnabled()) { - logger.trace("Sending notification: " + genericMessage); - } - this.eventBus.send(AbstractMessageService.MESSAGE_KEY, genericMessage); + this.eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message); } catch (Throwable t) { - handleError(session, t); - } - } - - private void handleError(final StompSession session, Throwable t) { - logger.error("Terminating STOMP session due to failure to send message: ", t); - sendErrorMessage(session, t.getMessage()); - if (removeSubscriptions(session)) { - // TODO: send error event including exception info + logger.error("Terminating STOMP session due to failure to send message: ", t); + sendErrorMessage(session, t); } } - private void sendErrorMessage(StompSession session, String errorText) { - StompHeaders headers = new StompHeaders(); - headers.setMessage(errorText); - StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers); - try { - session.sendMessage(errorMessage); - } - catch (Throwable t) { - // ignore - } - } + protected void handleConnect(final WebSocketSession session, Message message) throws IOException { - protected String handleConnect(final StompSession session, StompMessage stompMessage) throws IOException { + StompHeaders connectStompHeaders = new StompHeaders(message.getHeaders(), true); + StompHeaders connectedStompHeaders = new StompHeaders(StompCommand.CONNECTED); - StompHeaders headers = new StompHeaders(); - Set acceptVersions = stompMessage.getHeaders().getAcceptVersion(); + Set acceptVersions = connectStompHeaders.getAcceptVersion(); if (acceptVersions.contains("1.2")) { - headers.setVersion("1.2"); + connectedStompHeaders.setAcceptVersion("1.2"); } else if (acceptVersions.contains("1.1")) { - headers.setVersion("1.1"); + connectedStompHeaders.setAcceptVersion("1.1"); } else if (acceptVersions.isEmpty()) { // 1.0 } else { - throw new StompException("Unsupported version '" + acceptVersions + "'"); + throw new StompConversionException("Unsupported version '" + acceptVersions + "'"); } - headers.setHeartbeat(0,0); // TODO - headers.setId(session.getId()); + connectedStompHeaders.setHeartbeat(0,0); // TODO // TODO: security - session.sendMessage(new StompMessage(StompCommand.CONNECTED, headers)); - - String replyKey = "relay-message" + session.getId(); - - EventRegistration registration = this.eventBus.registerConsumer(replyKey, - new EventConsumer() { - @Override - public void accept(StompMessage message) { - try { - if (StompCommand.CONNECTED.equals(message.getCommand())) { - // TODO: skip for now (we already sent CONNECTED) - return; - } - if (logger.isTraceEnabled()) { - logger.trace("Relaying back to client: " + message); - } - session.sendMessage(message); - } - catch (Throwable t) { - handleError(session, t); - } - } - }); - addRegistration(session, registration); - - return replyKey; + Message connectedMessage = new GenericMessage(new byte[0], connectedStompHeaders.getMessageHeaders()); + byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage); + session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } - protected String handleSubscribe(final StompSession session, StompMessage message) { - - final String subscriptionId = message.getHeaders().getId(); - String replyKey = getSubscriptionReplyKey(session, subscriptionId); - - // TODO: extract and remember "ack" mode - // http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE_ack_Header - - if (logger.isTraceEnabled()) { - logger.trace("Adding subscription, key=" + replyKey); - } - - EventRegistration registration = this.eventBus.registerConsumer(replyKey, new EventConsumer>() { - @Override - public void accept(Message replyMessage) { - - StompHeaders headers = new StompHeaders(); - headers.setSubscription(subscriptionId); - - headerMapper.fromMessageHeaders(replyMessage.getHeaders(), headers); - - byte[] payload; - try { - MediaType contentType = headers.getContentType(); - payload = payloadConverter.convertToPayload(replyMessage.getPayload(), contentType); - } - catch (Exception e) { - logger.error("Failed to send " + replyMessage, e); - return; - } - - try { - StompMessage stompMessage = new StompMessage(StompCommand.MESSAGE, headers, payload); - session.sendMessage(stompMessage); - } - catch (Throwable t) { - handleError(session, t); - } - } - }); - addRegistration(session, registration); - - return replyKey; - + protected void handleSubscribe(Message message) { // TODO: need a way to communicate back if subscription was successfully created or // not in which case an ERROR should be sent back and close the connection // http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE } - private String getSubscriptionReplyKey(StompSession session, String subscriptionId) { - return StompCommand.SUBSCRIBE + ":" + session.getId() + ":" + subscriptionId; - } - - private void addRegistration(StompSession session, EventRegistration registration) { - String sessionId = session.getId(); - List list = this.registrationsBySession.get(sessionId); - if (list == null) { - list = new ArrayList(); - this.registrationsBySession.put(sessionId, list); - } - list.add(registration); - } - - protected void handleUnsubscribe(StompSession session, StompMessage message) { - cancelRegistration(session, message.getHeaders().getId()); - } - - private void cancelRegistration(StompSession session, String subscriptionId) { - String key = getSubscriptionReplyKey(session, subscriptionId); - List list = this.registrationsBySession.get(session.getId()); - for (EventRegistration registration : list) { - if (registration.getRegistrationKey().equals(key)) { - if (logger.isDebugEnabled()) { - logger.debug("Cancelling subscription, key=" + key); - } - list.remove(registration); - registration.cancel(); - } - } + protected void handleUnsubscribe(Message message) { } - protected void handleSend(StompSession session, StompMessage stompMessage) { + protected void handleMessage(Message stompMessage) { } - protected void handleDisconnect(StompSession session, StompMessage stompMessage) { - removeSubscriptions(session); - } - - private boolean removeSubscriptions(StompSession session) { - String sessionId = session.getId(); - List registrations = this.registrationsBySession.remove(sessionId); - if (CollectionUtils.isEmpty(registrations)) { - return false; - } - if (logger.isTraceEnabled()) { - logger.trace("Cancelling " + registrations.size() + " subscriptions for session=" + sessionId); - } - for (EventRegistration registration : registrations) { - registration.cancel(); - } - return true; + protected void handleDisconnect(Message stompMessage) { } - - private final class ConnectionClosedTask implements Runnable { - - private final StompSession session; - - private ConnectionClosedTask(StompSession session) { - this.session = session; - } - - @Override - public void run() { - removeSubscriptions(session); - if (logger.isTraceEnabled()) { - logger.trace("Sending notification for closed connection: " + session.getId()); - } - eventBus.send(AbstractMessageService.CLIENT_CONNECTION_CLOSED_KEY, session.getId()); - } + @Override + public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { + eventBus.send(AbstractMessageService.CLIENT_CONNECTION_CLOSED_KEY, session.getId()); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/WebSocketStompSession.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/WebSocketStompSession.java deleted file mode 100644 index cbcc143b132..00000000000 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/WebSocketStompSession.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.messaging.stomp.socket; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import org.springframework.util.Assert; -import org.springframework.web.messaging.stomp.StompCommand; -import org.springframework.web.messaging.stomp.StompMessage; -import org.springframework.web.messaging.stomp.StompSession; -import org.springframework.web.messaging.stomp.support.StompMessageConverter; -import org.springframework.web.socket.CloseStatus; -import org.springframework.web.socket.TextMessage; -import org.springframework.web.socket.WebSocketSession; - - -/** - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class WebSocketStompSession implements StompSession { - - private final String id; - - private WebSocketSession webSocketSession; - - private final StompMessageConverter messageConverter; - - private final List connectionClosedTasks = new ArrayList(); - - - public WebSocketStompSession(WebSocketSession webSocketSession, StompMessageConverter messageConverter) { - Assert.notNull(webSocketSession, "webSocketSession is required"); - this.id = webSocketSession.getId(); - this.webSocketSession = webSocketSession; - this.messageConverter = messageConverter; - } - - @Override - public String getId() { - return this.id; - } - - @Override - public void sendMessage(StompMessage message) throws IOException { - - Assert.notNull(this.webSocketSession, "Cannot send message without active session"); - - try { - byte[] bytes = this.messageConverter.fromStompMessage(message); - this.webSocketSession.sendMessage(new TextMessage(new String(bytes, StompMessage.CHARSET))); - } - finally { - if (StompCommand.ERROR.equals(message.getCommand())) { - this.webSocketSession.close(CloseStatus.PROTOCOL_ERROR); - this.webSocketSession = null; - } - } - } - - public void registerConnectionClosedTask(Runnable task) { - this.connectionClosedTasks.add(task); - } - - public void handleConnectionClosed() { - for (Runnable task : this.connectionClosedTasks) { - try { - task.run(); - } - catch (Throwable t) { - // ignore - } - } - } - -} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java index 33bbb274a59..bbe024514ea 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java @@ -31,6 +31,7 @@ import javax.net.SocketFactory; import org.springframework.core.task.TaskExecutor; import org.springframework.http.MediaType; +import org.springframework.messaging.GenericMessage; import org.springframework.messaging.Message; import org.springframework.web.messaging.converter.CompositeMessageConverter; import org.springframework.web.messaging.converter.MessageConverter; @@ -38,7 +39,6 @@ import org.springframework.web.messaging.event.EventBus; import org.springframework.web.messaging.service.AbstractMessageService; import org.springframework.web.messaging.stomp.StompCommand; import org.springframework.web.messaging.stomp.StompHeaders; -import org.springframework.web.messaging.stomp.StompMessage; import reactor.util.Assert; @@ -57,8 +57,6 @@ public class RelayStompService extends AbstractMessageService { private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); - private final StompHeaderMapper stompHeaderMapper = new StompHeaderMapper(); - public RelayStompService(EventBus eventBus, TaskExecutor executor) { super(eventBus); @@ -84,8 +82,7 @@ public class RelayStompService extends AbstractMessageService { forwardMessage(message, StompCommand.CONNECT); - String replyTo = (String) message.getHeaders().getReplyChannel(); - RelayReadTask readTask = new RelayReadTask(sessionId, replyTo, session); + RelayReadTask readTask = new RelayReadTask(sessionId, session); this.taskExecutor.execute(readTask); } catch (Throwable t) { @@ -96,23 +93,25 @@ public class RelayStompService extends AbstractMessageService { private void forwardMessage(Message message, StompCommand command) { - String sessionId = (String) message.getHeaders().get("sessionId"); + StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false); + String sessionId = stompHeaders.getSessionId(); RelaySession session = RelayStompService.this.relaySessions.get(sessionId); Assert.notNull(session, "RelaySession not found"); try { - StompHeaders stompHeaders = new StompHeaders(); - this.stompHeaderMapper.fromMessageHeaders(message.getHeaders(), stompHeaders); + if (stompHeaders.getProtocolMessageType() == null) { + stompHeaders.setProtocolMessageType(StompCommand.SEND); + } MediaType contentType = stompHeaders.getContentType(); byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType); - StompMessage stompMessage = new StompMessage(command, stompHeaders, payload); + Message byteMessage = new GenericMessage(payload, stompHeaders.getMessageHeaders()); if (logger.isTraceEnabled()) { - logger.trace("Forwarding: " + stompMessage); + logger.trace("Forwarding: " + byteMessage); } - byte[] bytesToWrite = this.stompMessageConverter.fromStompMessage(stompMessage); - session.getOutputStream().write(bytesToWrite); + byte[] bytes = this.stompMessageConverter.fromMessage(byteMessage); + session.getOutputStream().write(bytes); session.getOutputStream().flush(); } catch (Exception ex) { @@ -200,13 +199,12 @@ public class RelayStompService extends AbstractMessageService { private final class RelayReadTask implements Runnable { - private final String stompSessionId; - private final String replyTo; + private final String sessionId; + private final RelaySession session; - private RelayReadTask(String stompSessionId, String replyTo, RelaySession session) { - this.stompSessionId = stompSessionId; - this.replyTo = replyTo; + private RelayReadTask(String sessionId, RelaySession session) { + this.sessionId = sessionId; this.session = session; } @@ -221,28 +219,28 @@ public class RelayStompService extends AbstractMessageService { } else if (b == 0x00) { byte[] bytes = out.toByteArray(); - StompMessage message = RelayStompService.this.stompMessageConverter.toStompMessage(bytes); - getEventBus().send(this.replyTo, message); + Message message = stompMessageConverter.toMessage(bytes, sessionId); + getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message); out.reset(); } else { out.write(b); } } - logger.debug("Socket closed, STOMP session=" + stompSessionId); + logger.debug("Socket closed, STOMP session=" + sessionId); sendErrorMessage("Lost connection"); } catch (IOException e) { logger.error("Socket error: " + e.getMessage()); - clearRelaySession(stompSessionId); + clearRelaySession(sessionId); } } private void sendErrorMessage(String message) { - StompHeaders headers = new StompHeaders(); - headers.setMessage(message); - StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers); - getEventBus().send(this.replyTo, errorMessage); + StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR); + stompHeaders.setMessage(message); + Message errorMessage = new GenericMessage(new byte[0], stompHeaders.getMessageHeaders()); + getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, errorMessage); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompHeaderMapper.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompHeaderMapper.java deleted file mode 100644 index 849ca1bf580..00000000000 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompHeaderMapper.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.messaging.stomp.support; - -import java.util.HashMap; -import java.util.Map; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.http.MediaType; -import org.springframework.messaging.MessageHeaders; -import org.springframework.web.messaging.stomp.StompHeaders; - - -/** - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class StompHeaderMapper { - - private static Log logger = LogFactory.getLog(StompHeaderMapper.class); - - private static final String[][] stompHeaderNames; - - static { - stompHeaderNames = new String[2][StompHeaders.STANDARD_HEADER_NAMES.size()]; - for (int i=0 ; i < StompHeaders.STANDARD_HEADER_NAMES.size(); i++) { - stompHeaderNames[0][i] = StompHeaders.STANDARD_HEADER_NAMES.get(i); - stompHeaderNames[1][i] = "stomp." + StompHeaders.STANDARD_HEADER_NAMES.get(i); - } - } - - - public Map toMessageHeaders(StompHeaders stompHeaders) { - - Map headers = new HashMap(); - - // prefixed STOMP headers - for (int i=0; i < stompHeaderNames[0].length; i++) { - String header = stompHeaderNames[0][i]; - if (stompHeaders.containsKey(header)) { - String prefixedHeader = stompHeaderNames[1][i]; - headers.put(prefixedHeader, stompHeaders.getFirst(header)); - } - } - - // for generic use (not-prefixed) - if (stompHeaders.getDestination() != null) { - headers.put("destination", stompHeaders.getDestination()); - } - if (stompHeaders.getContentType() != null) { - headers.put("content-type", stompHeaders.getContentType()); - } - - return headers; - } - - public void fromMessageHeaders(MessageHeaders messageHeaders, StompHeaders stompHeaders) { - - // prefixed STOMP headers - for (int i=0; i < stompHeaderNames[0].length; i++) { - String prefixedHeader = stompHeaderNames[1][i]; - if (messageHeaders.containsKey(prefixedHeader)) { - String header = stompHeaderNames[0][i]; - stompHeaders.add(header, (String) messageHeaders.get(prefixedHeader)); - } - } - - // generic (not prefixed) - String destination = (String) messageHeaders.get("destination"); - if (destination != null) { - stompHeaders.setDestination(destination); - } - Object contentType = messageHeaders.get("content-type"); - if (contentType != null) { - if (contentType instanceof String) { - stompHeaders.setContentType(MediaType.valueOf((String) contentType)); - } - else if (contentType instanceof MediaType) { - stompHeaders.setContentType((MediaType) contentType); - } - else { - logger.warn("Invalid contentType class: " + contentType.getClass()); - } - } - } - -} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java index b120459309c..a8453307ce6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java @@ -17,134 +17,160 @@ package org.springframework.web.messaging.stomp.support; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.util.List; +import java.nio.charset.Charset; +import java.util.Map; import java.util.Map.Entry; +import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; import org.springframework.util.Assert; import org.springframework.web.messaging.stomp.StompCommand; -import org.springframework.web.messaging.stomp.StompException; +import org.springframework.web.messaging.stomp.StompConversionException; import org.springframework.web.messaging.stomp.StompHeaders; -import org.springframework.web.messaging.stomp.StompMessage; + /** * @author Gary Russell + * @author Rossen Stoyanchev * @since 4.0 - * */ public class StompMessageConverter { + private static final Charset STOMP_CHARSET = Charset.forName("UTF-8"); + public static final byte LF = 0x0a; public static final byte CR = 0x0d; private static final byte COLON = ':'; + /** - * @param bytes a complete STOMP message (without the trailing 0x00). + * @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String. */ - public StompMessage toStompMessage(Object stomp) { - Assert.state(stomp instanceof String || stomp instanceof byte[], "'stomp' must be String or byte[]"); - byte[] stompBytes = null; - if (stomp instanceof String) { - stompBytes = ((String) stomp).getBytes(StompMessage.CHARSET); + public Message toMessage(Object stompContent, String sessionId) { + + byte[] byteContent = null; + if (stompContent instanceof String) { + byteContent = ((String) stompContent).getBytes(STOMP_CHARSET); + } + else if (stompContent instanceof byte[]){ + byteContent = (byte[]) stompContent; } else { - stompBytes = (byte[]) stomp; + throw new IllegalArgumentException( + "stompContent is neither String nor byte[]: " + stompContent.getClass()); } - int totalLength = stompBytes.length; - if (stompBytes[totalLength-1] == 0) { + + int totalLength = byteContent.length; + if (byteContent[totalLength-1] == 0) { totalLength--; } - int payloadIndex = findPayloadStart(stompBytes); + + int payloadIndex = findIndexOfPayload(byteContent); if (payloadIndex == 0) { - throw new StompException("No command found"); + throw new StompConversionException("No command found"); } - String headerString = new String(stompBytes, 0, payloadIndex, StompMessage.CHARSET); - Parser parser = new Parser(headerString); - StompHeaders headers = new StompHeaders(); + + String headerContent = new String(byteContent, 0, payloadIndex, STOMP_CHARSET); + Parser parser = new Parser(headerContent); + // TODO: validate command and whether a payload is allowed StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim()); Assert.notNull(command, "No command found"); + + StompHeaders stompHeaders = new StompHeaders(command); + stompHeaders.setSessionId(sessionId); + while (parser.hasNext()) { String header = parser.nextToken(COLON); if (header != null) { if (parser.hasNext()) { String value = parser.nextToken(LF); - headers.add(header, value); + stompHeaders.getRawHeaders().put(header, value); } else { - throw new StompException("Parse exception for " + headerString); + throw new StompConversionException("Parse exception for " + headerContent); } } } + byte[] payload = new byte[totalLength - payloadIndex]; - System.arraycopy(stompBytes, payloadIndex, payload, 0, totalLength - payloadIndex); - return new StompMessage(command, headers, payload); - } + System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex); - public byte[] fromStompMessage(StompMessage message) { - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - StompHeaders headers = message.getHeaders(); - StompCommand command = message.getCommand(); - try { - outputStream.write(command.toString().getBytes("UTF-8")); - outputStream.write(LF); - for (Entry> entry : headers.entrySet()) { - String key = entry.getKey(); - key = replaceAllOutbound(key); - for (String value : entry.getValue()) { - outputStream.write(key.getBytes("UTF-8")); - outputStream.write(COLON); - value = replaceAllOutbound(value); - outputStream.write(value.getBytes("UTF-8")); - outputStream.write(LF); - } - } - outputStream.write(LF); - outputStream.write(message.getPayload()); - outputStream.write(0); - return outputStream.toByteArray(); - } - catch (IOException e) { - throw new StompException("Failed to serialize " + message, e); - } - } + stompHeaders.updateMessageHeaders(); - private String replaceAllOutbound(String key) { - return key.replaceAll("\\\\", "\\\\") - .replaceAll(":", "\\\\c") - .replaceAll("\n", "\\\\n") - .replaceAll("\r", "\\\\r"); + return createMessage(command, stompHeaders.getMessageHeaders(), payload); } - private int findPayloadStart(byte[] bytes) { + private int findIndexOfPayload(byte[] bytes) { int i; // ignore any leading EOL from the previous message for (i = 0; i < bytes.length; i++) { - if (bytes[i] != '\n' && bytes[i] != '\r' ) { + if (bytes[i] != '\n' && bytes[i] != '\r') { break; } bytes[i] = ' '; } - int payloadOffset = 0; + int index = 0; for (; i < bytes.length - 1; i++) { - if ((bytes[i] == LF && bytes[i+1] == LF)) { - payloadOffset = i + 2; + if (bytes[i] == LF && bytes[i+1] == LF) { + index = i + 2; break; } - if (i < bytes.length - 3 && - (bytes[i] == CR && bytes[i+1] == LF && - bytes[i+2] == CR && bytes[i+3] == LF)) { - payloadOffset = i + 4; + if ((i < (bytes.length - 3)) && + (bytes[i] == CR && bytes[i+1] == LF && bytes[i+2] == CR && bytes[i+3] == LF)) { + index = i + 4; break; } } if (i >= bytes.length) { - throw new StompException("No end of headers found"); + throw new StompConversionException("No end of headers found"); } - return payloadOffset; + return index; + } + + protected Message createMessage(StompCommand command, Map headers, byte[] payload) { + return new GenericMessage(payload, headers); } + public byte[] fromMessage(Message message) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + MessageHeaders messageHeaders = message.getHeaders(); + StompHeaders stompHeaders = new StompHeaders(messageHeaders, false); + stompHeaders.updateRawHeaders(); + try { + out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8")); + out.write(LF); + for (Entry entry : stompHeaders.getRawHeaders().entrySet()) { + String key = entry.getKey(); + key = replaceAllOutbound(key); + String value = entry.getValue(); + out.write(key.getBytes("UTF-8")); + out.write(COLON); + value = replaceAllOutbound(value); + out.write(value.getBytes("UTF-8")); + out.write(LF); + } + out.write(LF); + out.write(message.getPayload()); + out.write(0); + return out.toByteArray(); + } + catch (IOException e) { + throw new StompConversionException("Failed to serialize " + message, e); + } + } + + private String replaceAllOutbound(String key) { + return key.replaceAll("\\\\", "\\\\") + .replaceAll(":", "\\\\c") + .replaceAll("\n", "\\\\n") + .replaceAll("\r", "\\\\r"); + } + + private class Parser { private final String content; @@ -177,7 +203,7 @@ public class StompMessageConverter { return null; } else { - throw new StompException("No delimiter found at offset " + offset + " in " + this.content); + throw new StompConversionException("No delimiter found at offset " + offset + " in " + this.content); } } int escapeAt = this.content.indexOf('\\', this.offset); @@ -192,7 +218,7 @@ public class StompMessageConverter { .replaceAll("\\\\\\\\", "\\\\"); } else { - throw new StompException("Invalid escape sequence \\" + escaped); + throw new StompConversionException("Invalid escape sequence \\" + escaped); } } int length = token.length(); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/StompMessageInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DestinationMessage.java similarity index 64% rename from spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/StompMessageInterceptor.java rename to spring-websocket/src/main/java/org/springframework/web/messaging/support/DestinationMessage.java index c72407f10d1..8daa9b37ec9 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/StompMessageInterceptor.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DestinationMessage.java @@ -14,25 +14,29 @@ * limitations under the License. */ -package org.springframework.web.messaging.stomp.socket; +package org.springframework.web.messaging.support; -import org.springframework.web.messaging.stomp.StompMessage; +import java.util.Map; + +import org.springframework.messaging.GenericMessage; /** + * * @author Rossen Stoyanchev * @since 4.0 */ -public interface StompMessageInterceptor { +public class DestinationMessage extends GenericMessage { - boolean handleConnect(StompMessage message); - boolean handleSubscribe(StompMessage message); + public DestinationMessage(T payload, Map headers) { + super(payload, headers); + } - boolean handleUnsubscribe(StompMessage message); + public DestinationMessage(T payload) { + super(payload); + } - StompMessage handleSend(StompMessage message); - void handleDisconnect(); } diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java new file mode 100644 index 00000000000..f359100c954 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java @@ -0,0 +1,138 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.messaging.stomp.support; + +import java.util.Collections; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.web.messaging.MessageType; +import org.springframework.web.messaging.stomp.StompHeaders; +import org.springframework.web.messaging.stomp.StompCommand; + +import static org.junit.Assert.*; + +/** + * @author Gary Russell + * @author Rossen Stoyanchev + */ +public class StompMessageConverterTests { + + private StompMessageConverter converter; + + + @Before + public void setup() { + this.converter = new StompMessageConverter(); + } + + @Test + public void connectFrame() throws Exception { + + String accept = "accept-version:1.1\n"; + String host = "host:github.org\n"; + String frame = "\n\n\nCONNECT\n" + accept + host + "\n"; + Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); + + assertEquals(0, message.getPayload().length); + + MessageHeaders messageHeaders = message.getHeaders(); + StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); + assertEquals(6, stompHeaders.getMessageHeaders().size()); + assertEquals(MessageType.CONNECT, stompHeaders.getMessageType()); + assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand()); + assertEquals("session-123", stompHeaders.getSessionId()); + assertNotNull(messageHeaders.get(MessageHeaders.ID)); + assertNotNull(messageHeaders.get(MessageHeaders.TIMESTAMP)); + assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); + assertEquals("github.org", stompHeaders.getRawHeaders().get("host")); + + String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); + + assertEquals("CONNECT\n", convertedBack.substring(0,8)); + assertTrue(convertedBack.contains(accept)); + assertTrue(convertedBack.contains(host)); + } + + @Test + public void connectWithEscapes() throws Exception { + + String accept = "accept-version:1.1\n"; + String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; + String frame = "CONNECT\n" + accept + host + "\n"; + Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); + + assertEquals(0, message.getPayload().length); + + MessageHeaders messageHeaders = message.getHeaders(); + StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); + assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); + assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt")); + + String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); + + assertEquals("CONNECT\n", convertedBack.substring(0,8)); + assertTrue(convertedBack.contains(accept)); + assertTrue(convertedBack.contains(host)); + } + + @Test + public void connectCR12() throws Exception { + + String accept = "accept-version:1.2\n"; + String host = "host:github.org\n"; + String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; + Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); + + assertEquals(0, message.getPayload().length); + + MessageHeaders messageHeaders = message.getHeaders(); + StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); + assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion()); + assertEquals("github.org", stompHeaders.getRawHeaders().get("host")); + + String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); + + assertEquals("CONNECT\n", convertedBack.substring(0,8)); + assertTrue(convertedBack.contains(accept)); + assertTrue(convertedBack.contains(host)); + } + + @Test + public void connectWithEscapesAndCR12() throws Exception { + + String accept = "accept-version:1.1\n"; + String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; + String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; + Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); + + assertEquals(0, message.getPayload().length); + + MessageHeaders messageHeaders = message.getHeaders(); + StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); + assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); + assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt")); + + String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); + + assertEquals("CONNECT\n", convertedBack.substring(0,8)); + assertTrue(convertedBack.contains(accept)); + assertTrue(convertedBack.contains(host)); + } + +}