diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageHeaders.java b/spring-context/src/main/java/org/springframework/messaging/MessageHeaders.java
index 21ae96dd23a..7fb6a26aba0 100644
--- a/spring-context/src/main/java/org/springframework/messaging/MessageHeaders.java
+++ b/spring-context/src/main/java/org/springframework/messaging/MessageHeaders.java
@@ -37,8 +37,9 @@ import org.apache.commons.logging.LogFactory;
* The headers for a {@link Message}.
* IMPORTANT: MessageHeaders are immutable. Any mutating operation (e.g., put(..), putAll(..) etc.)
* will result in {@link UnsupportedOperationException}
+ *
*
- * TODO: update javadoc
+ * TODO: update below instructions
*
*
To create MessageHeaders instance use fluent MessageBuilder API
*
@@ -76,16 +77,7 @@ public class MessageHeaders implements Map, Serializable {
public static final String TIMESTAMP = "timestamp";
- public static final String REPLY_CHANNEL = "replyChannel";
-
- public static final String ERROR_CHANNEL = "errorChannel";
-
- public static final String CONTENT_TYPE = "content-type";
-
- // DESTINATION ?
-
- public static final List HEADER_NAMES =
- Arrays.asList(ID, TIMESTAMP, REPLY_CHANNEL, ERROR_CHANNEL, CONTENT_TYPE);
+ public static final List HEADER_NAMES = Arrays.asList(ID, TIMESTAMP);
private final Map headers;
@@ -111,14 +103,6 @@ public class MessageHeaders implements Map, Serializable {
return this.get(TIMESTAMP, Long.class);
}
- public Object getReplyChannel() {
- return this.get(REPLY_CHANNEL);
- }
-
- public Object getErrorChannel() {
- return this.get(ERROR_CHANNEL);
- }
-
@SuppressWarnings("unchecked")
public T get(Object key, Class type) {
Object value = this.headers.get(key);
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java
new file mode 100644
index 00000000000..0e020779c45
--- /dev/null
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java
@@ -0,0 +1,163 @@
+/*
+ * Copyright 2002-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.messaging;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.springframework.http.MediaType;
+import org.springframework.messaging.MessageHeaders;
+import org.springframework.util.CollectionUtils;
+
+
+/**
+ *
+ * @author Rossen Stoyanchev
+ * @since 4.0
+ */
+public class PubSubHeaders {
+
+ private static final String DESTINATIONS = "destinations";
+
+ private static final String CONTENT_TYPE = "contentType";
+
+ private static final String MESSAGE_TYPE = "messageType";
+
+ private static final String SUBSCRIPTION_ID = "subscriptionId";
+
+ private static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
+
+ private static final String SESSION_ID = "sessionId";
+
+ private static final String RAW_HEADERS = "rawHeaders";
+
+
+ private final Map messageHeaders;
+
+ private final Map rawHeaders;
+
+
+ /**
+ * Constructor for building new headers.
+ *
+ * @param messageType the message type
+ * @param protocolMessageType the protocol-specific message type or command
+ */
+ public PubSubHeaders(MessageType messageType, Object protocolMessageType) {
+
+ this.messageHeaders = new HashMap();
+ this.messageHeaders.put(MESSAGE_TYPE, messageType);
+ if (protocolMessageType != null) {
+ this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
+ }
+
+ this.rawHeaders = new HashMap();
+ this.messageHeaders.put(RAW_HEADERS, this.rawHeaders);
+ }
+
+ public PubSubHeaders() {
+ this(MessageType.MESSAGE, null);
+ }
+
+ /**
+ * Constructor for access to existing {@link MessageHeaders}.
+ *
+ * @param messageHeaders
+ */
+ @SuppressWarnings("unchecked")
+ public PubSubHeaders(MessageHeaders messageHeaders, boolean readOnly) {
+
+ this.messageHeaders = readOnly ? messageHeaders : new HashMap(messageHeaders);
+ this.rawHeaders = this.messageHeaders.containsKey(RAW_HEADERS) ?
+ (Map) messageHeaders.get(RAW_HEADERS) : new HashMap();
+
+ if (this.messageHeaders.get(MESSAGE_TYPE) == null) {
+ this.messageHeaders.put(MESSAGE_TYPE, MessageType.MESSAGE);
+ }
+ }
+
+
+ public Map getMessageHeaders() {
+ return this.messageHeaders;
+ }
+
+ public Map getRawHeaders() {
+ return this.rawHeaders;
+ }
+
+ public MessageType getMessageType() {
+ return (MessageType) this.messageHeaders.get(MESSAGE_TYPE);
+ }
+
+ public void setProtocolMessageType(Object protocolMessageType) {
+ this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
+ }
+
+ public Object getProtocolMessageType() {
+ return this.messageHeaders.get(PROTOCOL_MESSAGE_TYPE);
+ }
+
+ public void setDestination(String destination) {
+ this.messageHeaders.put(DESTINATIONS, Arrays.asList(destination));
+ }
+
+ public String getDestination() {
+ @SuppressWarnings("unchecked")
+ List destination = (List) messageHeaders.get(DESTINATIONS);
+ return CollectionUtils.isEmpty(destination) ? null : destination.get(0);
+ }
+
+ @SuppressWarnings("unchecked")
+ public List getDestinations() {
+ return (List) messageHeaders.get(DESTINATIONS);
+ }
+
+ public void setDestinations(List destinations) {
+ if (destinations != null) {
+ this.messageHeaders.put(DESTINATIONS, destinations);
+ }
+ }
+
+ public MediaType getContentType() {
+ return (MediaType) this.messageHeaders.get(CONTENT_TYPE);
+ }
+
+ public void setContentType(MediaType mediaType) {
+ if (mediaType != null) {
+ this.messageHeaders.put(CONTENT_TYPE, mediaType);
+ }
+ }
+
+ public String getSubscriptionId() {
+ return (String) this.messageHeaders.get(SUBSCRIPTION_ID);
+ }
+
+ public void setSubscriptionId(String subscriptionId) {
+ this.messageHeaders.put(SUBSCRIPTION_ID, subscriptionId);
+ }
+
+ public String getSessionId() {
+ return (String) this.messageHeaders.get(SESSION_ID);
+ }
+
+ public void setSessionId(String sessionId) {
+ this.messageHeaders.put(SESSION_ID, sessionId);
+ }
+
+}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/event/ReactorEventBus.java b/spring-websocket/src/main/java/org/springframework/web/messaging/event/ReactorEventBus.java
index 8c782250081..a6e9ce600fc 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/event/ReactorEventBus.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/event/ReactorEventBus.java
@@ -16,6 +16,9 @@
package org.springframework.web.messaging.event;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
import reactor.core.Reactor;
import reactor.fn.Consumer;
import reactor.fn.Event;
@@ -28,6 +31,8 @@ import reactor.fn.selector.ObjectSelector;
*/
public class ReactorEventBus implements EventBus {
+ private static Log logger = LogFactory.getLog(ReactorEventBus.class);
+
private final Reactor reactor;
@@ -37,6 +42,9 @@ public class ReactorEventBus implements EventBus {
@Override
public void send(String key, Object data) {
+ if (logger.isTraceEnabled()) {
+ logger.trace("Sending notification key=" + key + ", data=" + data);
+ }
this.reactor.notify(key, Event.wrap(data));
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
index 68a26d00d1a..d8a81b662d3 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
@@ -27,6 +27,7 @@ import org.springframework.util.AntPathMatcher;
import org.springframework.util.Assert;
import org.springframework.util.PathMatcher;
import org.springframework.web.messaging.MessageType;
+import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.event.EventConsumer;
@@ -37,11 +38,14 @@ import org.springframework.web.messaging.event.EventConsumer;
*/
public abstract class AbstractMessageService {
- public static final String MESSAGE_KEY = "messageKey";
+ protected final Log logger = LogFactory.getLog(getClass());
+
+
+ public static final String CLIENT_TO_SERVER_MESSAGE_KEY = "clientToServerMessageKey";
public static final String CLIENT_CONNECTION_CLOSED_KEY = "clientConnectionClosed";
- protected final Log logger = LogFactory.getLog(getClass());
+ public static final String SERVER_TO_CLIENT_MESSAGE_KEY = "serverToClientMessageKey";
private final EventBus eventBus;
@@ -58,7 +62,7 @@ public abstract class AbstractMessageService {
Assert.notNull(reactor, "reactor is required");
this.eventBus = reactor;
- this.eventBus.registerConsumer(MESSAGE_KEY, new EventConsumer>() {
+ this.eventBus.registerConsumer(CLIENT_TO_SERVER_MESSAGE_KEY, new EventConsumer>() {
@Override
public void accept(Message> message) {
@@ -124,7 +128,8 @@ public abstract class AbstractMessageService {
}
private boolean isAllowedDestination(Message> message) {
- String destination = (String) message.getHeaders().get("destination");
+ PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
+ String destination = headers.getDestination();
if (destination == null) {
return true;
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
index a1a96c0a9ca..963a978368b 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
@@ -17,14 +17,13 @@
package org.springframework.web.messaging.service;
import java.util.ArrayList;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
-import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
+import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.event.EventBus;
@@ -61,26 +60,21 @@ public class PubSubMessageService extends AbstractMessageService {
logger.debug("Message received: " + message);
}
- Map headers = new HashMap();
- headers.put("destination", message.getHeaders().get("destination"));
-
- MediaType contentType = (MediaType) message.getHeaders().get("content-type");
- headers.put("content-type", contentType);
-
try {
// Convert to byte[] payload before the fan-out
- byte[] payload = payloadConverter.convertToPayload(message.getPayload(), contentType);
- message = new GenericMessage(payload, headers);
+ PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
+ byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType());
+ message = new GenericMessage(payload, message.getHeaders());
- getEventBus().send(getPublishKey(message), message);
+ getEventBus().send(getPublishKey(inHeaders.getDestination()), message);
}
catch (Exception ex) {
logger.error("Failed to publish " + message, ex);
}
}
- private String getPublishKey(Message> message) {
- return "destination:" + (String) message.getHeaders().get("destination");
+ private String getPublishKey(String destination) {
+ return "destination:" + destination;
}
@Override
@@ -88,12 +82,20 @@ public class PubSubMessageService extends AbstractMessageService {
if (logger.isDebugEnabled()) {
logger.debug("Subscribe " + message);
}
- final String replyKey = (String) message.getHeaders().getReplyChannel();
- EventRegistration registration = getEventBus().registerConsumer(getPublishKey(message),
+ PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
+ final String subscriptionId = headers.getSubscriptionId();
+ EventRegistration registration = getEventBus().registerConsumer(getPublishKey(headers.getDestination()),
new EventConsumer>() {
@Override
public void accept(Message> message) {
- getEventBus().send(replyKey, message);
+ PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
+ PubSubHeaders outHeaders = new PubSubHeaders();
+ outHeaders.setDestinations(inHeaders.getDestinations());
+ outHeaders.setContentType(inHeaders.getContentType());
+ outHeaders.setSubscriptionId(subscriptionId);
+ Object payload = message.getPayload();
+ message = new GenericMessage(payload, outHeaders.getMessageHeaders());
+ getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message);
}
});
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java
index 570ec633c4a..55ab59ea86b 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java
@@ -35,6 +35,7 @@ import org.springframework.messaging.annotation.MessageMapping;
import org.springframework.stereotype.Controller;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils.MethodFilter;
+import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.annotation.SubscribeEvent;
import org.springframework.web.messaging.annotation.UnsubscribeEvent;
import org.springframework.web.messaging.converter.MessageConverter;
@@ -166,7 +167,8 @@ public class AnnotationMessageService extends AbstractMessageService implements
private void handleMessage(final Message> message, Map handlerMethods) {
- String destination = (String) message.getHeaders().get("destination");
+ PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
+ String destination = headers.getDestination();
HandlerMethod match = getHandlerMethod(destination, handlerMethods);
if (match == null) {
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java
index 728ce0261ca..ad16b5f97c0 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java
@@ -16,16 +16,11 @@
package org.springframework.web.messaging.service.method;
-import java.util.HashMap;
-import java.util.Map;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
-import org.springframework.web.messaging.MessageType;
+import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.service.AbstractMessageService;
@@ -38,8 +33,6 @@ import reactor.util.Assert;
*/
public class MessageChannelArgumentResolver implements ArgumentResolver {
- private static Log logger = LogFactory.getLog(MessageChannelArgumentResolver.class);
-
private final EventBus eventBus;
@@ -56,24 +49,18 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
@Override
public Object resolveArgument(MethodParameter parameter, Message> message) throws Exception {
- final String sessionId = (String) message.getHeaders().get("sessionId");
+ final String sessionId = new PubSubHeaders(message.getHeaders(), true).getSessionId();
return new MessageChannel() {
@Override
public boolean send(Message> message) {
- Map headers = new HashMap(message.getHeaders());
- headers.put("messageType", MessageType.MESSAGE);
- headers.put("sessionId", sessionId);
- message = new GenericMessage(message.getPayload(), headers);
-
- if (logger.isTraceEnabled()) {
- logger.trace("Sending notification: " + message);
- }
+ PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), false);
+ headers.setSessionId(sessionId);
+ message = new GenericMessage(message.getPayload(), headers.getMessageHeaders());
- String key = AbstractMessageService.MESSAGE_KEY;
- MessageChannelArgumentResolver.this.eventBus.send(key, message);
+ eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message);
return true;
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java
index b5f0e1efd10..4a357b94d61 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java
@@ -16,11 +16,12 @@
package org.springframework.web.messaging.service.method;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.springframework.core.MethodParameter;
+import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
+import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.event.EventBus;
+import org.springframework.web.messaging.service.AbstractMessageService;
import reactor.util.Assert;
@@ -31,8 +32,6 @@ import reactor.util.Assert;
*/
public class MessageReturnValueHandler implements ReturnValueHandler {
- private static Log logger = LogFactory.getLog(MessageReturnValueHandler.class);
-
private final EventBus eventBus;
@@ -67,13 +66,17 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
return;
}
- String replyTo = (String) message.getHeaders().getReplyChannel();
- Assert.notNull(replyTo, "Cannot reply to: " + message);
+ PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
+ String sessionId = inHeaders.getSessionId();
+ String subscriptionId = inHeaders.getSubscriptionId();
+ Assert.notNull(subscriptionId, "No subscription id: " + message);
- if (logger.isTraceEnabled()) {
- logger.trace("Sending notification: " + message);
- }
- this.eventBus.send(replyTo, returnMessage);
+ PubSubHeaders outHeaders = new PubSubHeaders(returnMessage.getHeaders(), false);
+ outHeaders.setSessionId(sessionId);
+ outHeaders.setSubscriptionId(subscriptionId);
+ returnMessage = new GenericMessage(returnMessage.getPayload(), outHeaders.getMessageHeaders());
+
+ this.eventBus.send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, returnMessage);
}
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompCommand.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompCommand.java
index cb99fd9e34d..95ca8d17751 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompCommand.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompCommand.java
@@ -16,6 +16,11 @@
package org.springframework.web.messaging.stomp;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.springframework.web.messaging.MessageType;
+
/**
*
@@ -43,4 +48,21 @@ public enum StompCommand {
RECEIPT,
ERROR;
+
+ private static Map commandToMessageType = new HashMap();
+
+ static {
+ commandToMessageType.put(StompCommand.CONNECT, MessageType.CONNECT);
+ commandToMessageType.put(StompCommand.STOMP, MessageType.CONNECT);
+ commandToMessageType.put(StompCommand.SEND, MessageType.MESSAGE);
+ commandToMessageType.put(StompCommand.SUBSCRIBE, MessageType.SUBSCRIBE);
+ commandToMessageType.put(StompCommand.UNSUBSCRIBE, MessageType.UNSUBSCRIBE);
+ commandToMessageType.put(StompCommand.DISCONNECT, MessageType.DISCONNECT);
+ }
+
+ public MessageType getMessageType() {
+ MessageType messageType = commandToMessageType.get(this);
+ return (messageType != null) ? messageType : MessageType.OTHER;
+ }
+
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompException.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompConversionException.java
similarity index 82%
rename from spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompException.java
rename to spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompConversionException.java
index 7f2915aa9e7..bbc951b9a17 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompException.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompConversionException.java
@@ -20,17 +20,16 @@ import org.springframework.core.NestedRuntimeException;
/**
* @author Gary Russell
* @since 4.0
- *
*/
@SuppressWarnings("serial")
-public class StompException extends NestedRuntimeException {
+public class StompConversionException extends NestedRuntimeException {
- public StompException(String msg, Throwable cause) {
+ public StompConversionException(String msg, Throwable cause) {
super(msg, cause);
}
- public StompException(String msg) {
+ public StompConversionException(String msg) {
super(msg);
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java
index d8e5535de4c..eb6e0fafdf1 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java
@@ -16,42 +16,32 @@
package org.springframework.web.messaging.stomp;
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.Collection;
import java.util.Collections;
-import java.util.LinkedHashMap;
-import java.util.LinkedList;
import java.util.List;
-import java.util.Map;
import java.util.Set;
import org.springframework.http.MediaType;
-import org.springframework.util.Assert;
-import org.springframework.util.MultiValueMap;
+import org.springframework.messaging.MessageHeaders;
import org.springframework.util.StringUtils;
+import org.springframework.web.messaging.PubSubHeaders;
+
+import reactor.util.Assert;
/**
+ * STOMP adapter for {@link MessageHeaders}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
-public class StompHeaders implements MultiValueMap, Serializable {
-
- private static final long serialVersionUID = 1L;
+public class StompHeaders extends PubSubHeaders {
- // TODO: separate client from server headers so they can't be mixed
-
- // Client
private static final String ID = "id";
private static final String HOST = "host";
private static final String ACCEPT_VERSION = "accept-version";
- // Server
-
private static final String MESSAGE_ID = "message-id";
private static final String RECEIPT_ID = "receipt-id";
@@ -62,8 +52,6 @@ public class StompHeaders implements MultiValueMap, Serializable
private static final String MESSAGE = "message";
- // Client and Server
-
private static final String ACK = "ack";
private static final String DESTINATION = "destination";
@@ -75,270 +63,167 @@ public class StompHeaders implements MultiValueMap, Serializable
private static final String HEARTBEAT = "heart-beat";
- public static final List STANDARD_HEADER_NAMES =
- Arrays.asList(ID, HOST, ACCEPT_VERSION, MESSAGE_ID, RECEIPT_ID, SUBSCRIPTION,
- VERSION, MESSAGE, ACK, DESTINATION, CONTENT_LENGTH, CONTENT_TYPE, HEARTBEAT);
-
-
- private final Map> headers;
-
-
/**
- * Private constructor that can create read-only {@code StompHeaders} instances.
+ * Constructor for building new headers.
+ *
+ * @param command the STOMP command
*/
- private StompHeaders(Map> headers, boolean readOnly) {
- Assert.notNull(headers, "'headers' must not be null");
- if (readOnly) {
- Map> map = new LinkedHashMap>(headers.size());
- for (Entry> entry : headers.entrySet()) {
- List values = Collections.unmodifiableList(entry.getValue());
- map.put(entry.getKey(), values);
- }
- this.headers = Collections.unmodifiableMap(map);
- }
- else {
- this.headers = headers;
- }
+ public StompHeaders(StompCommand command) {
+ super(command.getMessageType(), command);
}
/**
- * Constructs a new, empty instance of the {@code StompHeaders} object.
+ * Constructor for access to existing {@link MessageHeaders}.
+ *
+ * @param messageHeaders the existing message headers
+ * @param readOnly whether the resulting instance will be used for read-only access,
+ * if {@code true}, then set methods will throw exceptions; if {@code false}
+ * they will work.
*/
- public StompHeaders() {
- this(new LinkedHashMap>(4), false);
+ public StompHeaders(MessageHeaders messageHeaders, boolean readOnly) {
+ super(messageHeaders, readOnly);
}
- /**
- * Returns {@code StompHeaders} object that can only be read, not written to.
- */
- public static StompHeaders readOnlyStompHeaders(StompHeaders headers) {
- return new StompHeaders(headers, true);
+ @Override
+ public StompCommand getProtocolMessageType() {
+ return (StompCommand) super.getProtocolMessageType();
+ }
+
+ public StompCommand getStompCommand() {
+ return (StompCommand) super.getProtocolMessageType();
}
public Set getAcceptVersion() {
- String rawValue = getFirst(ACCEPT_VERSION);
+ String rawValue = getRawHeaders().get(ACCEPT_VERSION);
return (rawValue != null) ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.emptySet();
}
public void setAcceptVersion(String acceptVersion) {
- set(ACCEPT_VERSION, acceptVersion);
+ getRawHeaders().put(ACCEPT_VERSION, acceptVersion);
}
- public String getVersion() {
- return getFirst(VERSION);
- }
-
- public void setVersion(String version) {
- set(VERSION, version);
- }
-
- public String getDestination() {
- return getFirst(DESTINATION);
+ @Override
+ public void setDestination(String destination) {
+ if (destination != null) {
+ super.setDestination(destination);
+ getRawHeaders().put(DESTINATION, destination);
+ }
}
- public void setDestination(String destination) {
- set(DESTINATION, destination);
+ @Override
+ public void setDestinations(List destinations) {
+ if (destinations != null) {
+ super.setDestinations(destinations);
+ getRawHeaders().put(DESTINATION, destinations.get(0));
+ }
}
- public MediaType getContentType() {
- String contentType = getFirst(CONTENT_TYPE);
- return StringUtils.hasText(contentType) ? MediaType.valueOf(contentType) : null;
+ public long[] getHeartbeat() {
+ String rawValue = getRawHeaders().get(HEARTBEAT);
+ if (!StringUtils.hasText(rawValue)) {
+ return null;
+ }
+ String[] rawValues = StringUtils.commaDelimitedListToStringArray(rawValue);
+ // TODO assertions
+ return new long[] { Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])};
}
public void setContentType(MediaType mediaType) {
if (mediaType != null) {
- set(CONTENT_TYPE, mediaType.toString());
- }
- else {
- remove(CONTENT_TYPE);
+ super.setContentType(mediaType);
+ getRawHeaders().put(CONTENT_TYPE, mediaType.toString());
}
}
public Integer getContentLength() {
- String contentLength = getFirst(CONTENT_LENGTH);
+ String contentLength = getRawHeaders().get(CONTENT_LENGTH);
return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null;
}
public void setContentLength(int contentLength) {
- set(CONTENT_LENGTH, String.valueOf(contentLength));
+ getRawHeaders().put(CONTENT_LENGTH, String.valueOf(contentLength));
}
- public long[] getHeartbeat() {
- String rawValue = getFirst(HEARTBEAT);
- if (!StringUtils.hasText(rawValue)) {
- return null;
- }
- String[] rawValues = StringUtils.commaDelimitedListToStringArray(rawValue);
- // TODO assertions
- return new long[] { Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])};
+ @Override
+ public String getSubscriptionId() {
+ return StompCommand.SUBSCRIBE.equals(getStompCommand()) ? getRawHeaders().get(ID) : null;
+ }
+
+ @Override
+ public void setSubscriptionId(String subscriptionId) {
+ Assert.isTrue(StompCommand.MESSAGE.equals(getStompCommand()),
+ "\"subscription\" can only be set on a STOMP MESSAGE frame");
+ super.setSubscriptionId(subscriptionId);
+ getRawHeaders().put(SUBSCRIPTION, subscriptionId);
}
public void setHeartbeat(long cx, long cy) {
- set(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
+ getRawHeaders().put(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
}
- public String getId() {
- return getFirst(ID);
+ public String getMessage() {
+ return getRawHeaders().get(MESSAGE);
}
- public void setId(String id) {
- set(ID, id);
+ public void setMessage(String content) {
+ getRawHeaders().put(MESSAGE, content);
}
public String getMessageId() {
- return getFirst(MESSAGE_ID);
+ return getRawHeaders().get(MESSAGE_ID);
}
public void setMessageId(String id) {
- set(MESSAGE_ID, id);
- }
-
- public String getSubscription() {
- return getFirst(SUBSCRIPTION);
- }
-
- public void setSubscription(String id) {
- set(SUBSCRIPTION, id);
+ getRawHeaders().put(MESSAGE_ID, id);
}
- public String getMessage() {
- return getFirst(MESSAGE);
+ public String getVersion() {
+ return getRawHeaders().get(VERSION);
}
- public void setMessage(String id) {
- set(MESSAGE, id);
+ public void setVersion(String version) {
+ getRawHeaders().put(VERSION, version);
}
- // MultiValueMap methods
-
/**
- * Return the first header value for the given header name, if any.
- * @param headerName the header name
- * @return the first header value; or {@code null}
+ * Update generic message headers from raw headers. This method only needs to be
+ * invoked when raw headers are added via {@link #getRawHeaders()}.
*/
- public String getFirst(String headerName) {
- List headerValues = headers.get(headerName);
- return headerValues != null ? headerValues.get(0) : null;
- }
-
- /**
- * Add the given, single header value under the given name.
- * @param headerName the header name
- * @param headerValue the header value
- * @throws UnsupportedOperationException if adding headers is not supported
- * @see #put(String, List)
- * @see #set(String, String)
- */
- public void add(String headerName, String headerValue) {
- List headerValues = headers.get(headerName);
- if (headerValues == null) {
- headerValues = new LinkedList();
- this.headers.put(headerName, headerValues);
+ public void updateMessageHeaders() {
+ String destination = getRawHeaders().get(DESTINATION);
+ if (destination != null) {
+ setDestination(destination);
+ }
+ String contentType = getRawHeaders().get(CONTENT_TYPE);
+ if (contentType != null) {
+ setContentType(MediaType.parseMediaType(contentType));
+ }
+ if (StompCommand.SUBSCRIBE.equals(getStompCommand())) {
+ if (getRawHeaders().get(ID) != null) {
+ super.setSubscriptionId(getRawHeaders().get(ID));
+ }
}
- headerValues.add(headerValue);
}
/**
- * Set the given, single header value under the given name.
- * @param headerName the header name
- * @param headerValue the header value
- * @throws UnsupportedOperationException if adding headers is not supported
- * @see #put(String, List)
- * @see #add(String, String)
+ * Update raw headers from generic message headers. This method only needs to be
+ * invoked if creating {@link StompHeaders} from {@link MessageHeaders} that never
+ * contained raw headers.
*/
- public void set(String headerName, String headerValue) {
- List headerValues = new LinkedList();
- headerValues.add(headerValue);
- headers.put(headerName, headerValues);
- }
-
- public void setAll(Map values) {
- for (Entry entry : values.entrySet()) {
- set(entry.getKey(), entry.getValue());
+ public void updateRawHeaders() {
+ String destination = getDestination();
+ if (destination != null) {
+ getRawHeaders().put(DESTINATION, destination);
}
- }
-
- public Map toSingleValueMap() {
- LinkedHashMap singleValueMap = new LinkedHashMap(this.headers.size());
- for (Entry> entry : headers.entrySet()) {
- singleValueMap.put(entry.getKey(), entry.getValue().get(0));
- }
- return singleValueMap;
- }
-
-
- // Map implementation
-
- public int size() {
- return this.headers.size();
- }
-
- public boolean isEmpty() {
- return this.headers.isEmpty();
- }
-
- public boolean containsKey(Object key) {
- return this.headers.containsKey(key);
- }
-
- public boolean containsValue(Object value) {
- return this.headers.containsValue(value);
- }
-
- public List get(Object key) {
- return this.headers.get(key);
- }
-
- public List put(String key, List value) {
- return this.headers.put(key, value);
- }
-
- public List remove(Object key) {
- return this.headers.remove(key);
- }
-
- public void putAll(Map extends String, ? extends List> m) {
- this.headers.putAll(m);
- }
-
- public void clear() {
- this.headers.clear();
- }
-
- public Set keySet() {
- return this.headers.keySet();
- }
-
- public Collection> values() {
- return this.headers.values();
- }
-
- public Set>> entrySet() {
- return this.headers.entrySet();
- }
-
-
- @Override
- public boolean equals(Object other) {
- if (this == other) {
- return true;
+ MediaType contentType = getContentType();
+ if (contentType != null) {
+ getRawHeaders().put(CONTENT_TYPE, contentType.toString());
}
- if (!(other instanceof StompHeaders)) {
- return false;
+ String subscriptionId = getSubscriptionId();
+ if (subscriptionId != null) {
+ getRawHeaders().put(SUBSCRIPTION, subscriptionId);
}
- StompHeaders otherHeaders = (StompHeaders) other;
- return this.headers.equals(otherHeaders.headers);
- }
-
- @Override
- public int hashCode() {
- return this.headers.hashCode();
- }
-
- @Override
- public String toString() {
- return this.headers.toString();
}
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompMessage.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompMessage.java
deleted file mode 100644
index 78bbe43efd4..00000000000
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompMessage.java
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright 2002-2013 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.web.messaging.stomp;
-
-import java.nio.charset.Charset;
-
-
-/**
- *
- * @author Rossen Stoyanchev
- * @since 4.0
- */
-public class StompMessage {
-
- public static final Charset CHARSET = Charset.forName("UTF-8");
-
- private final StompCommand command;
-
- private final StompHeaders headers;
-
- private final byte[] payload;
-
- private String sessionId;
-
-
- public StompMessage(StompCommand command, StompHeaders headers, byte[] payload) {
- this.command = command;
- this.headers = (headers != null) ? headers : new StompHeaders();
- this.payload = payload;
- }
-
- /**
- * Constructor for empty payload message.
- */
- public StompMessage(StompCommand command, StompHeaders headers) {
- this(command, headers, new byte[0]);
- }
-
- public StompCommand getCommand() {
- return this.command;
- }
-
- public StompHeaders getHeaders() {
- return this.headers;
- }
-
- public byte[] getPayload() {
- return this.payload;
- }
-
- public void setSessionId(String sessionId) {
- this.sessionId = sessionId;
- }
-
- public String getSessionId() {
- return this.sessionId;
- }
-
- @Override
- public String toString() {
- return "StompMessage [" + command + ", headers=" + this.headers + ", payload=" + new String(this.payload) + "]";
- }
-
-}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompSession.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompSession.java
deleted file mode 100644
index 1da941be2d8..00000000000
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompSession.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright 2002-2013 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.web.messaging.stomp;
-
-import java.io.IOException;
-
-
-/**
- * @author Rossen Stoyanchev
- * @since 4.0
- */
-public interface StompSession {
-
- String getId();
-
- /**
- * TODO...
- *
- * If the message is a STOMP ERROR message, the session will also be closed.
- */
- void sendMessage(StompMessage message) throws IOException;
-
- /**
- * Register a task to be invoked if the underlying connection is closed.
- */
- void registerConnectionClosedTask(Runnable task);
-
-}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java
index 43641d28937..d8e66221446 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java
@@ -15,14 +15,14 @@
*/
package org.springframework.web.messaging.stomp.socket;
+import java.nio.charset.Charset;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
-import org.springframework.util.Assert;
-import org.springframework.web.messaging.stomp.StompCommand;
+import org.springframework.messaging.GenericMessage;
+import org.springframework.messaging.Message;
import org.springframework.web.messaging.stomp.StompHeaders;
-import org.springframework.web.messaging.stomp.StompMessage;
-import org.springframework.web.messaging.stomp.StompSession;
+import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.support.StompMessageConverter;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
@@ -36,57 +36,65 @@ import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
*/
public abstract class AbstractStompWebSocketHandler extends TextWebSocketHandlerAdapter {
- private final StompMessageConverter messageConverter = new StompMessageConverter();
+ private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
- private final Map sessions = new ConcurrentHashMap();
+ private final Map sessions = new ConcurrentHashMap();
+ public StompMessageConverter getStompMessageConverter() {
+ return this.stompMessageConverter;
+ }
+
+ protected WebSocketSession getWebSocketSession(String sessionId) {
+ return this.sessions.get(sessionId);
+ }
+
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
- WebSocketStompSession stompSession = new WebSocketStompSession(session, this.messageConverter);
- this.sessions.put(session.getId(), stompSession);
+ this.sessions.put(session.getId(), session);
}
@Override
- protected void handleTextMessage(WebSocketSession session, TextMessage message) {
-
- StompSession stompSession = this.sessions.get(session.getId());
- Assert.notNull(stompSession, "No STOMP session for WebSocket session id=" + session.getId());
-
+ protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) {
try {
- StompMessage stompMessage = this.messageConverter.toStompMessage(message.getPayload());
- stompMessage.setSessionId(stompSession.getId());
+ String payload = textMessage.getPayload();
+ Message message = this.stompMessageConverter.toMessage(payload, session.getId());
// TODO: validate size limits
// http://stomp.github.io/stomp-specification-1.2.html#Size_Limits
- handleStompMessage(stompSession, stompMessage);
+ handleStompMessage(session, message);
// TODO: send RECEIPT message if incoming message has "receipt" header
// http://stomp.github.io/stomp-specification-1.2.html#Header_receipt
}
catch (Throwable error) {
- StompHeaders headers = new StompHeaders();
- headers.setMessage(error.getMessage());
- StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers);
- try {
- stompSession.sendMessage(errorMessage);
- }
- catch (Throwable t) {
- // ignore
- }
+ sendErrorMessage(session, error);
}
}
- protected abstract void handleStompMessage(StompSession stompSession, StompMessage stompMessage);
+ protected void sendErrorMessage(WebSocketSession session, Throwable error) {
+
+ StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR);
+ stompHeaders.setMessage(error.getMessage());
+
+ Message errorMessage = new GenericMessage(new byte[0], stompHeaders.getMessageHeaders());
+ byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage);
+
+ try {
+ session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
+ }
+ catch (Throwable t) {
+ // ignore
+ }
+ }
+
+ protected abstract void handleStompMessage(WebSocketSession session, Message message);
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
- WebSocketStompSession stompSession = this.sessions.remove(session.getId());
- if (stompSession != null) {
- stompSession.handleConnectionClosed();
- }
+ this.sessions.remove(session.getId());
}
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java
index 506841fc9ce..a21cf4d0744 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java
@@ -16,32 +16,28 @@
package org.springframework.web.messaging.stomp.socket;
import java.io.IOException;
-import java.util.ArrayList;
+import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
-import org.springframework.messaging.MessageHeaders;
-import org.springframework.util.CollectionUtils;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.event.EventConsumer;
-import org.springframework.web.messaging.event.EventRegistration;
import org.springframework.web.messaging.service.AbstractMessageService;
import org.springframework.web.messaging.stomp.StompCommand;
-import org.springframework.web.messaging.stomp.StompException;
+import org.springframework.web.messaging.stomp.StompConversionException;
import org.springframework.web.messaging.stomp.StompHeaders;
-import org.springframework.web.messaging.stomp.StompMessage;
-import org.springframework.web.messaging.stomp.StompSession;
-import org.springframework.web.messaging.stomp.support.StompHeaderMapper;
+import org.springframework.web.socket.CloseStatus;
+import org.springframework.web.socket.TextMessage;
+import org.springframework.web.socket.WebSocketSession;
/**
* @author Gary Russell
@@ -57,14 +53,59 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
private MessageConverter payloadConverter = new CompositeMessageConverter(null);
- private final StompHeaderMapper headerMapper = new StompHeaderMapper();
-
- private Map> registrationsBySession =
- new ConcurrentHashMap>();
-
public DefaultStompWebSocketHandler(EventBus eventBus) {
+
this.eventBus = eventBus;
+
+ this.eventBus.registerConsumer(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY,
+ new EventConsumer>() {
+ @Override
+ public void accept(Message> message) {
+
+ StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false);
+ if (stompHeaders.getProtocolMessageType() == null) {
+ stompHeaders.setProtocolMessageType(StompCommand.MESSAGE);
+ }
+
+ if (StompCommand.CONNECTED.equals(stompHeaders.getStompCommand())) {
+ // Ignore for now since we already sent it
+ return;
+ }
+
+ String sessionId = stompHeaders.getSessionId();
+ WebSocketSession session = getWebSocketSession(sessionId);
+
+ byte[] payload;
+ try {
+ MediaType contentType = stompHeaders.getContentType();
+ payload = payloadConverter.convertToPayload(message.getPayload(), contentType);
+ }
+ catch (Exception e) {
+ logger.error("Failed to send " + message, e);
+ return;
+ }
+
+ try {
+ Map messageHeaders = stompHeaders.getMessageHeaders();
+ Message byteMessage = new GenericMessage(payload, messageHeaders);
+ byte[] bytes = getStompMessageConverter().fromMessage(byteMessage);
+ session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
+ }
+ catch (Throwable t) {
+ sendErrorMessage(session, t);
+ }
+ finally {
+ if (StompCommand.ERROR.equals(stompHeaders.getStompCommand())) {
+ try {
+ session.close(CloseStatus.PROTOCOL_ERROR);
+ }
+ catch (IOException e) {
+ }
+ }
+ }
+ }
+ });
}
@@ -72,252 +113,83 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
this.payloadConverter = new CompositeMessageConverter(converters);
}
- public void handleStompMessage(final StompSession session, StompMessage stompMessage) {
+ public void handleStompMessage(final WebSocketSession session, Message message) {
if (logger.isTraceEnabled()) {
- logger.trace("Processing: " + stompMessage);
+ logger.trace("Processing: " + message);
}
try {
- MessageType messageType = MessageType.OTHER;
- String replyKey = null;
-
- StompCommand command = stompMessage.getCommand();
- if (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command)) {
- session.registerConnectionClosedTask(new ConnectionClosedTask(session));
- messageType = MessageType.CONNECT;
- replyKey = handleConnect(session, stompMessage);
- }
- else if (StompCommand.SEND.equals(command)) {
- messageType = MessageType.MESSAGE;
- handleSend(session, stompMessage);
- }
- else if (StompCommand.SUBSCRIBE.equals(command)) {
- messageType = MessageType.SUBSCRIBE;
- replyKey = handleSubscribe(session, stompMessage);
+ StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), true);
+ MessageType messageType = stompHeaders.getMessageType();
+ if (MessageType.CONNECT.equals(messageType)) {
+ handleConnect(session, message);
}
- else if (StompCommand.UNSUBSCRIBE.equals(command)) {
- messageType = MessageType.UNSUBSCRIBE;
- handleUnsubscribe(session, stompMessage);
+ else if (MessageType.MESSAGE.equals(messageType)) {
+ handleMessage(message);
}
- else if (StompCommand.DISCONNECT.equals(command)) {
- messageType = MessageType.DISCONNECT;
- handleDisconnect(session, stompMessage);
+ else if (MessageType.SUBSCRIBE.equals(messageType)) {
+ handleSubscribe(message);
}
- else {
- sendErrorMessage(session, "Invalid STOMP command " + command);
- return;
+ else if (MessageType.UNSUBSCRIBE.equals(messageType)) {
+ handleUnsubscribe(message);
}
-
- Map messageHeaders = this.headerMapper.toMessageHeaders(stompMessage.getHeaders());
- messageHeaders.put("messageType", messageType);
- if (replyKey != null) {
- messageHeaders.put(MessageHeaders.REPLY_CHANNEL, replyKey);
+ else if (MessageType.DISCONNECT.equals(messageType)) {
+ handleDisconnect(message);
}
- messageHeaders.put("stompCommand", command);
- messageHeaders.put("sessionId", session.getId());
-
- Message genericMessage = new GenericMessage(stompMessage.getPayload(), messageHeaders);
-
- if (logger.isTraceEnabled()) {
- logger.trace("Sending notification: " + genericMessage);
- }
- this.eventBus.send(AbstractMessageService.MESSAGE_KEY, genericMessage);
+ this.eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message);
}
catch (Throwable t) {
- handleError(session, t);
- }
- }
-
- private void handleError(final StompSession session, Throwable t) {
- logger.error("Terminating STOMP session due to failure to send message: ", t);
- sendErrorMessage(session, t.getMessage());
- if (removeSubscriptions(session)) {
- // TODO: send error event including exception info
+ logger.error("Terminating STOMP session due to failure to send message: ", t);
+ sendErrorMessage(session, t);
}
}
- private void sendErrorMessage(StompSession session, String errorText) {
- StompHeaders headers = new StompHeaders();
- headers.setMessage(errorText);
- StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers);
- try {
- session.sendMessage(errorMessage);
- }
- catch (Throwable t) {
- // ignore
- }
- }
+ protected void handleConnect(final WebSocketSession session, Message message) throws IOException {
- protected String handleConnect(final StompSession session, StompMessage stompMessage) throws IOException {
+ StompHeaders connectStompHeaders = new StompHeaders(message.getHeaders(), true);
+ StompHeaders connectedStompHeaders = new StompHeaders(StompCommand.CONNECTED);
- StompHeaders headers = new StompHeaders();
- Set acceptVersions = stompMessage.getHeaders().getAcceptVersion();
+ Set acceptVersions = connectStompHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) {
- headers.setVersion("1.2");
+ connectedStompHeaders.setAcceptVersion("1.2");
}
else if (acceptVersions.contains("1.1")) {
- headers.setVersion("1.1");
+ connectedStompHeaders.setAcceptVersion("1.1");
}
else if (acceptVersions.isEmpty()) {
// 1.0
}
else {
- throw new StompException("Unsupported version '" + acceptVersions + "'");
+ throw new StompConversionException("Unsupported version '" + acceptVersions + "'");
}
- headers.setHeartbeat(0,0); // TODO
- headers.setId(session.getId());
+ connectedStompHeaders.setHeartbeat(0,0); // TODO
// TODO: security
- session.sendMessage(new StompMessage(StompCommand.CONNECTED, headers));
-
- String replyKey = "relay-message" + session.getId();
-
- EventRegistration registration = this.eventBus.registerConsumer(replyKey,
- new EventConsumer() {
- @Override
- public void accept(StompMessage message) {
- try {
- if (StompCommand.CONNECTED.equals(message.getCommand())) {
- // TODO: skip for now (we already sent CONNECTED)
- return;
- }
- if (logger.isTraceEnabled()) {
- logger.trace("Relaying back to client: " + message);
- }
- session.sendMessage(message);
- }
- catch (Throwable t) {
- handleError(session, t);
- }
- }
- });
- addRegistration(session, registration);
-
- return replyKey;
+ Message connectedMessage = new GenericMessage(new byte[0], connectedStompHeaders.getMessageHeaders());
+ byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage);
+ session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
- protected String handleSubscribe(final StompSession session, StompMessage message) {
-
- final String subscriptionId = message.getHeaders().getId();
- String replyKey = getSubscriptionReplyKey(session, subscriptionId);
-
- // TODO: extract and remember "ack" mode
- // http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE_ack_Header
-
- if (logger.isTraceEnabled()) {
- logger.trace("Adding subscription, key=" + replyKey);
- }
-
- EventRegistration registration = this.eventBus.registerConsumer(replyKey, new EventConsumer>() {
- @Override
- public void accept(Message> replyMessage) {
-
- StompHeaders headers = new StompHeaders();
- headers.setSubscription(subscriptionId);
-
- headerMapper.fromMessageHeaders(replyMessage.getHeaders(), headers);
-
- byte[] payload;
- try {
- MediaType contentType = headers.getContentType();
- payload = payloadConverter.convertToPayload(replyMessage.getPayload(), contentType);
- }
- catch (Exception e) {
- logger.error("Failed to send " + replyMessage, e);
- return;
- }
-
- try {
- StompMessage stompMessage = new StompMessage(StompCommand.MESSAGE, headers, payload);
- session.sendMessage(stompMessage);
- }
- catch (Throwable t) {
- handleError(session, t);
- }
- }
- });
- addRegistration(session, registration);
-
- return replyKey;
-
+ protected void handleSubscribe(Message message) {
// TODO: need a way to communicate back if subscription was successfully created or
// not in which case an ERROR should be sent back and close the connection
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE
}
- private String getSubscriptionReplyKey(StompSession session, String subscriptionId) {
- return StompCommand.SUBSCRIBE + ":" + session.getId() + ":" + subscriptionId;
- }
-
- private void addRegistration(StompSession session, EventRegistration registration) {
- String sessionId = session.getId();
- List list = this.registrationsBySession.get(sessionId);
- if (list == null) {
- list = new ArrayList();
- this.registrationsBySession.put(sessionId, list);
- }
- list.add(registration);
- }
-
- protected void handleUnsubscribe(StompSession session, StompMessage message) {
- cancelRegistration(session, message.getHeaders().getId());
- }
-
- private void cancelRegistration(StompSession session, String subscriptionId) {
- String key = getSubscriptionReplyKey(session, subscriptionId);
- List list = this.registrationsBySession.get(session.getId());
- for (EventRegistration registration : list) {
- if (registration.getRegistrationKey().equals(key)) {
- if (logger.isDebugEnabled()) {
- logger.debug("Cancelling subscription, key=" + key);
- }
- list.remove(registration);
- registration.cancel();
- }
- }
+ protected void handleUnsubscribe(Message message) {
}
- protected void handleSend(StompSession session, StompMessage stompMessage) {
+ protected void handleMessage(Message stompMessage) {
}
- protected void handleDisconnect(StompSession session, StompMessage stompMessage) {
- removeSubscriptions(session);
- }
-
- private boolean removeSubscriptions(StompSession session) {
- String sessionId = session.getId();
- List registrations = this.registrationsBySession.remove(sessionId);
- if (CollectionUtils.isEmpty(registrations)) {
- return false;
- }
- if (logger.isTraceEnabled()) {
- logger.trace("Cancelling " + registrations.size() + " subscriptions for session=" + sessionId);
- }
- for (EventRegistration registration : registrations) {
- registration.cancel();
- }
- return true;
+ protected void handleDisconnect(Message stompMessage) {
}
-
- private final class ConnectionClosedTask implements Runnable {
-
- private final StompSession session;
-
- private ConnectionClosedTask(StompSession session) {
- this.session = session;
- }
-
- @Override
- public void run() {
- removeSubscriptions(session);
- if (logger.isTraceEnabled()) {
- logger.trace("Sending notification for closed connection: " + session.getId());
- }
- eventBus.send(AbstractMessageService.CLIENT_CONNECTION_CLOSED_KEY, session.getId());
- }
+ @Override
+ public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
+ eventBus.send(AbstractMessageService.CLIENT_CONNECTION_CLOSED_KEY, session.getId());
}
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/WebSocketStompSession.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/WebSocketStompSession.java
deleted file mode 100644
index cbcc143b132..00000000000
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/WebSocketStompSession.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Copyright 2002-2013 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.web.messaging.stomp.socket;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.springframework.util.Assert;
-import org.springframework.web.messaging.stomp.StompCommand;
-import org.springframework.web.messaging.stomp.StompMessage;
-import org.springframework.web.messaging.stomp.StompSession;
-import org.springframework.web.messaging.stomp.support.StompMessageConverter;
-import org.springframework.web.socket.CloseStatus;
-import org.springframework.web.socket.TextMessage;
-import org.springframework.web.socket.WebSocketSession;
-
-
-/**
- * @author Rossen Stoyanchev
- * @since 4.0
- */
-public class WebSocketStompSession implements StompSession {
-
- private final String id;
-
- private WebSocketSession webSocketSession;
-
- private final StompMessageConverter messageConverter;
-
- private final List connectionClosedTasks = new ArrayList();
-
-
- public WebSocketStompSession(WebSocketSession webSocketSession, StompMessageConverter messageConverter) {
- Assert.notNull(webSocketSession, "webSocketSession is required");
- this.id = webSocketSession.getId();
- this.webSocketSession = webSocketSession;
- this.messageConverter = messageConverter;
- }
-
- @Override
- public String getId() {
- return this.id;
- }
-
- @Override
- public void sendMessage(StompMessage message) throws IOException {
-
- Assert.notNull(this.webSocketSession, "Cannot send message without active session");
-
- try {
- byte[] bytes = this.messageConverter.fromStompMessage(message);
- this.webSocketSession.sendMessage(new TextMessage(new String(bytes, StompMessage.CHARSET)));
- }
- finally {
- if (StompCommand.ERROR.equals(message.getCommand())) {
- this.webSocketSession.close(CloseStatus.PROTOCOL_ERROR);
- this.webSocketSession = null;
- }
- }
- }
-
- public void registerConnectionClosedTask(Runnable task) {
- this.connectionClosedTasks.add(task);
- }
-
- public void handleConnectionClosed() {
- for (Runnable task : this.connectionClosedTasks) {
- try {
- task.run();
- }
- catch (Throwable t) {
- // ignore
- }
- }
- }
-
-}
\ No newline at end of file
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java
index 33bbb274a59..bbe024514ea 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java
@@ -31,6 +31,7 @@ import javax.net.SocketFactory;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.MediaType;
+import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
@@ -38,7 +39,6 @@ import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.service.AbstractMessageService;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompHeaders;
-import org.springframework.web.messaging.stomp.StompMessage;
import reactor.util.Assert;
@@ -57,8 +57,6 @@ public class RelayStompService extends AbstractMessageService {
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
- private final StompHeaderMapper stompHeaderMapper = new StompHeaderMapper();
-
public RelayStompService(EventBus eventBus, TaskExecutor executor) {
super(eventBus);
@@ -84,8 +82,7 @@ public class RelayStompService extends AbstractMessageService {
forwardMessage(message, StompCommand.CONNECT);
- String replyTo = (String) message.getHeaders().getReplyChannel();
- RelayReadTask readTask = new RelayReadTask(sessionId, replyTo, session);
+ RelayReadTask readTask = new RelayReadTask(sessionId, session);
this.taskExecutor.execute(readTask);
}
catch (Throwable t) {
@@ -96,23 +93,25 @@ public class RelayStompService extends AbstractMessageService {
private void forwardMessage(Message> message, StompCommand command) {
- String sessionId = (String) message.getHeaders().get("sessionId");
+ StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false);
+ String sessionId = stompHeaders.getSessionId();
RelaySession session = RelayStompService.this.relaySessions.get(sessionId);
Assert.notNull(session, "RelaySession not found");
try {
- StompHeaders stompHeaders = new StompHeaders();
- this.stompHeaderMapper.fromMessageHeaders(message.getHeaders(), stompHeaders);
+ if (stompHeaders.getProtocolMessageType() == null) {
+ stompHeaders.setProtocolMessageType(StompCommand.SEND);
+ }
MediaType contentType = stompHeaders.getContentType();
byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType);
- StompMessage stompMessage = new StompMessage(command, stompHeaders, payload);
+ Message byteMessage = new GenericMessage(payload, stompHeaders.getMessageHeaders());
if (logger.isTraceEnabled()) {
- logger.trace("Forwarding: " + stompMessage);
+ logger.trace("Forwarding: " + byteMessage);
}
- byte[] bytesToWrite = this.stompMessageConverter.fromStompMessage(stompMessage);
- session.getOutputStream().write(bytesToWrite);
+ byte[] bytes = this.stompMessageConverter.fromMessage(byteMessage);
+ session.getOutputStream().write(bytes);
session.getOutputStream().flush();
}
catch (Exception ex) {
@@ -200,13 +199,12 @@ public class RelayStompService extends AbstractMessageService {
private final class RelayReadTask implements Runnable {
- private final String stompSessionId;
- private final String replyTo;
+ private final String sessionId;
+
private final RelaySession session;
- private RelayReadTask(String stompSessionId, String replyTo, RelaySession session) {
- this.stompSessionId = stompSessionId;
- this.replyTo = replyTo;
+ private RelayReadTask(String sessionId, RelaySession session) {
+ this.sessionId = sessionId;
this.session = session;
}
@@ -221,28 +219,28 @@ public class RelayStompService extends AbstractMessageService {
}
else if (b == 0x00) {
byte[] bytes = out.toByteArray();
- StompMessage message = RelayStompService.this.stompMessageConverter.toStompMessage(bytes);
- getEventBus().send(this.replyTo, message);
+ Message message = stompMessageConverter.toMessage(bytes, sessionId);
+ getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message);
out.reset();
}
else {
out.write(b);
}
}
- logger.debug("Socket closed, STOMP session=" + stompSessionId);
+ logger.debug("Socket closed, STOMP session=" + sessionId);
sendErrorMessage("Lost connection");
}
catch (IOException e) {
logger.error("Socket error: " + e.getMessage());
- clearRelaySession(stompSessionId);
+ clearRelaySession(sessionId);
}
}
private void sendErrorMessage(String message) {
- StompHeaders headers = new StompHeaders();
- headers.setMessage(message);
- StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers);
- getEventBus().send(this.replyTo, errorMessage);
+ StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR);
+ stompHeaders.setMessage(message);
+ Message errorMessage = new GenericMessage(new byte[0], stompHeaders.getMessageHeaders());
+ getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, errorMessage);
}
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompHeaderMapper.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompHeaderMapper.java
deleted file mode 100644
index 849ca1bf580..00000000000
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompHeaderMapper.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Copyright 2002-2013 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.web.messaging.stomp.support;
-
-import java.util.HashMap;
-import java.util.Map;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.springframework.http.MediaType;
-import org.springframework.messaging.MessageHeaders;
-import org.springframework.web.messaging.stomp.StompHeaders;
-
-
-/**
- * @author Rossen Stoyanchev
- * @since 4.0
- */
-public class StompHeaderMapper {
-
- private static Log logger = LogFactory.getLog(StompHeaderMapper.class);
-
- private static final String[][] stompHeaderNames;
-
- static {
- stompHeaderNames = new String[2][StompHeaders.STANDARD_HEADER_NAMES.size()];
- for (int i=0 ; i < StompHeaders.STANDARD_HEADER_NAMES.size(); i++) {
- stompHeaderNames[0][i] = StompHeaders.STANDARD_HEADER_NAMES.get(i);
- stompHeaderNames[1][i] = "stomp." + StompHeaders.STANDARD_HEADER_NAMES.get(i);
- }
- }
-
-
- public Map toMessageHeaders(StompHeaders stompHeaders) {
-
- Map headers = new HashMap();
-
- // prefixed STOMP headers
- for (int i=0; i < stompHeaderNames[0].length; i++) {
- String header = stompHeaderNames[0][i];
- if (stompHeaders.containsKey(header)) {
- String prefixedHeader = stompHeaderNames[1][i];
- headers.put(prefixedHeader, stompHeaders.getFirst(header));
- }
- }
-
- // for generic use (not-prefixed)
- if (stompHeaders.getDestination() != null) {
- headers.put("destination", stompHeaders.getDestination());
- }
- if (stompHeaders.getContentType() != null) {
- headers.put("content-type", stompHeaders.getContentType());
- }
-
- return headers;
- }
-
- public void fromMessageHeaders(MessageHeaders messageHeaders, StompHeaders stompHeaders) {
-
- // prefixed STOMP headers
- for (int i=0; i < stompHeaderNames[0].length; i++) {
- String prefixedHeader = stompHeaderNames[1][i];
- if (messageHeaders.containsKey(prefixedHeader)) {
- String header = stompHeaderNames[0][i];
- stompHeaders.add(header, (String) messageHeaders.get(prefixedHeader));
- }
- }
-
- // generic (not prefixed)
- String destination = (String) messageHeaders.get("destination");
- if (destination != null) {
- stompHeaders.setDestination(destination);
- }
- Object contentType = messageHeaders.get("content-type");
- if (contentType != null) {
- if (contentType instanceof String) {
- stompHeaders.setContentType(MediaType.valueOf((String) contentType));
- }
- else if (contentType instanceof MediaType) {
- stompHeaders.setContentType((MediaType) contentType);
- }
- else {
- logger.warn("Invalid contentType class: " + contentType.getClass());
- }
- }
- }
-
-}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java
index b120459309c..a8453307ce6 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java
@@ -17,134 +17,160 @@ package org.springframework.web.messaging.stomp.support;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
-import java.util.List;
+import java.nio.charset.Charset;
+import java.util.Map;
import java.util.Map.Entry;
+import org.springframework.messaging.GenericMessage;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.web.messaging.stomp.StompCommand;
-import org.springframework.web.messaging.stomp.StompException;
+import org.springframework.web.messaging.stomp.StompConversionException;
import org.springframework.web.messaging.stomp.StompHeaders;
-import org.springframework.web.messaging.stomp.StompMessage;
+
/**
* @author Gary Russell
+ * @author Rossen Stoyanchev
* @since 4.0
- *
*/
public class StompMessageConverter {
+ private static final Charset STOMP_CHARSET = Charset.forName("UTF-8");
+
public static final byte LF = 0x0a;
public static final byte CR = 0x0d;
private static final byte COLON = ':';
+
/**
- * @param bytes a complete STOMP message (without the trailing 0x00).
+ * @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String.
*/
- public StompMessage toStompMessage(Object stomp) {
- Assert.state(stomp instanceof String || stomp instanceof byte[], "'stomp' must be String or byte[]");
- byte[] stompBytes = null;
- if (stomp instanceof String) {
- stompBytes = ((String) stomp).getBytes(StompMessage.CHARSET);
+ public Message toMessage(Object stompContent, String sessionId) {
+
+ byte[] byteContent = null;
+ if (stompContent instanceof String) {
+ byteContent = ((String) stompContent).getBytes(STOMP_CHARSET);
+ }
+ else if (stompContent instanceof byte[]){
+ byteContent = (byte[]) stompContent;
}
else {
- stompBytes = (byte[]) stomp;
+ throw new IllegalArgumentException(
+ "stompContent is neither String nor byte[]: " + stompContent.getClass());
}
- int totalLength = stompBytes.length;
- if (stompBytes[totalLength-1] == 0) {
+
+ int totalLength = byteContent.length;
+ if (byteContent[totalLength-1] == 0) {
totalLength--;
}
- int payloadIndex = findPayloadStart(stompBytes);
+
+ int payloadIndex = findIndexOfPayload(byteContent);
if (payloadIndex == 0) {
- throw new StompException("No command found");
+ throw new StompConversionException("No command found");
}
- String headerString = new String(stompBytes, 0, payloadIndex, StompMessage.CHARSET);
- Parser parser = new Parser(headerString);
- StompHeaders headers = new StompHeaders();
+
+ String headerContent = new String(byteContent, 0, payloadIndex, STOMP_CHARSET);
+ Parser parser = new Parser(headerContent);
+
// TODO: validate command and whether a payload is allowed
StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim());
Assert.notNull(command, "No command found");
+
+ StompHeaders stompHeaders = new StompHeaders(command);
+ stompHeaders.setSessionId(sessionId);
+
while (parser.hasNext()) {
String header = parser.nextToken(COLON);
if (header != null) {
if (parser.hasNext()) {
String value = parser.nextToken(LF);
- headers.add(header, value);
+ stompHeaders.getRawHeaders().put(header, value);
}
else {
- throw new StompException("Parse exception for " + headerString);
+ throw new StompConversionException("Parse exception for " + headerContent);
}
}
}
+
byte[] payload = new byte[totalLength - payloadIndex];
- System.arraycopy(stompBytes, payloadIndex, payload, 0, totalLength - payloadIndex);
- return new StompMessage(command, headers, payload);
- }
+ System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
- public byte[] fromStompMessage(StompMessage message) {
- ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
- StompHeaders headers = message.getHeaders();
- StompCommand command = message.getCommand();
- try {
- outputStream.write(command.toString().getBytes("UTF-8"));
- outputStream.write(LF);
- for (Entry> entry : headers.entrySet()) {
- String key = entry.getKey();
- key = replaceAllOutbound(key);
- for (String value : entry.getValue()) {
- outputStream.write(key.getBytes("UTF-8"));
- outputStream.write(COLON);
- value = replaceAllOutbound(value);
- outputStream.write(value.getBytes("UTF-8"));
- outputStream.write(LF);
- }
- }
- outputStream.write(LF);
- outputStream.write(message.getPayload());
- outputStream.write(0);
- return outputStream.toByteArray();
- }
- catch (IOException e) {
- throw new StompException("Failed to serialize " + message, e);
- }
- }
+ stompHeaders.updateMessageHeaders();
- private String replaceAllOutbound(String key) {
- return key.replaceAll("\\\\", "\\\\")
- .replaceAll(":", "\\\\c")
- .replaceAll("\n", "\\\\n")
- .replaceAll("\r", "\\\\r");
+ return createMessage(command, stompHeaders.getMessageHeaders(), payload);
}
- private int findPayloadStart(byte[] bytes) {
+ private int findIndexOfPayload(byte[] bytes) {
int i;
// ignore any leading EOL from the previous message
for (i = 0; i < bytes.length; i++) {
- if (bytes[i] != '\n' && bytes[i] != '\r' ) {
+ if (bytes[i] != '\n' && bytes[i] != '\r') {
break;
}
bytes[i] = ' ';
}
- int payloadOffset = 0;
+ int index = 0;
for (; i < bytes.length - 1; i++) {
- if ((bytes[i] == LF && bytes[i+1] == LF)) {
- payloadOffset = i + 2;
+ if (bytes[i] == LF && bytes[i+1] == LF) {
+ index = i + 2;
break;
}
- if (i < bytes.length - 3 &&
- (bytes[i] == CR && bytes[i+1] == LF &&
- bytes[i+2] == CR && bytes[i+3] == LF)) {
- payloadOffset = i + 4;
+ if ((i < (bytes.length - 3)) &&
+ (bytes[i] == CR && bytes[i+1] == LF && bytes[i+2] == CR && bytes[i+3] == LF)) {
+ index = i + 4;
break;
}
}
if (i >= bytes.length) {
- throw new StompException("No end of headers found");
+ throw new StompConversionException("No end of headers found");
}
- return payloadOffset;
+ return index;
+ }
+
+ protected Message createMessage(StompCommand command, Map headers, byte[] payload) {
+ return new GenericMessage(payload, headers);
}
+ public byte[] fromMessage(Message message) {
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ MessageHeaders messageHeaders = message.getHeaders();
+ StompHeaders stompHeaders = new StompHeaders(messageHeaders, false);
+ stompHeaders.updateRawHeaders();
+ try {
+ out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8"));
+ out.write(LF);
+ for (Entry entry : stompHeaders.getRawHeaders().entrySet()) {
+ String key = entry.getKey();
+ key = replaceAllOutbound(key);
+ String value = entry.getValue();
+ out.write(key.getBytes("UTF-8"));
+ out.write(COLON);
+ value = replaceAllOutbound(value);
+ out.write(value.getBytes("UTF-8"));
+ out.write(LF);
+ }
+ out.write(LF);
+ out.write(message.getPayload());
+ out.write(0);
+ return out.toByteArray();
+ }
+ catch (IOException e) {
+ throw new StompConversionException("Failed to serialize " + message, e);
+ }
+ }
+
+ private String replaceAllOutbound(String key) {
+ return key.replaceAll("\\\\", "\\\\")
+ .replaceAll(":", "\\\\c")
+ .replaceAll("\n", "\\\\n")
+ .replaceAll("\r", "\\\\r");
+ }
+
+
private class Parser {
private final String content;
@@ -177,7 +203,7 @@ public class StompMessageConverter {
return null;
}
else {
- throw new StompException("No delimiter found at offset " + offset + " in " + this.content);
+ throw new StompConversionException("No delimiter found at offset " + offset + " in " + this.content);
}
}
int escapeAt = this.content.indexOf('\\', this.offset);
@@ -192,7 +218,7 @@ public class StompMessageConverter {
.replaceAll("\\\\\\\\", "\\\\");
}
else {
- throw new StompException("Invalid escape sequence \\" + escaped);
+ throw new StompConversionException("Invalid escape sequence \\" + escaped);
}
}
int length = token.length();
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/StompMessageInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DestinationMessage.java
similarity index 64%
rename from spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/StompMessageInterceptor.java
rename to spring-websocket/src/main/java/org/springframework/web/messaging/support/DestinationMessage.java
index c72407f10d1..8daa9b37ec9 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/StompMessageInterceptor.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DestinationMessage.java
@@ -14,25 +14,29 @@
* limitations under the License.
*/
-package org.springframework.web.messaging.stomp.socket;
+package org.springframework.web.messaging.support;
-import org.springframework.web.messaging.stomp.StompMessage;
+import java.util.Map;
+
+import org.springframework.messaging.GenericMessage;
/**
+ *
* @author Rossen Stoyanchev
* @since 4.0
*/
-public interface StompMessageInterceptor {
+public class DestinationMessage extends GenericMessage {
- boolean handleConnect(StompMessage message);
- boolean handleSubscribe(StompMessage message);
+ public DestinationMessage(T payload, Map headers) {
+ super(payload, headers);
+ }
- boolean handleUnsubscribe(StompMessage message);
+ public DestinationMessage(T payload) {
+ super(payload);
+ }
- StompMessage handleSend(StompMessage message);
- void handleDisconnect();
}
diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java
new file mode 100644
index 00000000000..f359100c954
--- /dev/null
+++ b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java
@@ -0,0 +1,138 @@
+/*
+ * Copyright 2002-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.web.messaging.stomp.support;
+
+import java.util.Collections;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.MessageHeaders;
+import org.springframework.web.messaging.MessageType;
+import org.springframework.web.messaging.stomp.StompHeaders;
+import org.springframework.web.messaging.stomp.StompCommand;
+
+import static org.junit.Assert.*;
+
+/**
+ * @author Gary Russell
+ * @author Rossen Stoyanchev
+ */
+public class StompMessageConverterTests {
+
+ private StompMessageConverter converter;
+
+
+ @Before
+ public void setup() {
+ this.converter = new StompMessageConverter();
+ }
+
+ @Test
+ public void connectFrame() throws Exception {
+
+ String accept = "accept-version:1.1\n";
+ String host = "host:github.org\n";
+ String frame = "\n\n\nCONNECT\n" + accept + host + "\n";
+ Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
+
+ assertEquals(0, message.getPayload().length);
+
+ MessageHeaders messageHeaders = message.getHeaders();
+ StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
+ assertEquals(6, stompHeaders.getMessageHeaders().size());
+ assertEquals(MessageType.CONNECT, stompHeaders.getMessageType());
+ assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand());
+ assertEquals("session-123", stompHeaders.getSessionId());
+ assertNotNull(messageHeaders.get(MessageHeaders.ID));
+ assertNotNull(messageHeaders.get(MessageHeaders.TIMESTAMP));
+ assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
+ assertEquals("github.org", stompHeaders.getRawHeaders().get("host"));
+
+ String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
+
+ assertEquals("CONNECT\n", convertedBack.substring(0,8));
+ assertTrue(convertedBack.contains(accept));
+ assertTrue(convertedBack.contains(host));
+ }
+
+ @Test
+ public void connectWithEscapes() throws Exception {
+
+ String accept = "accept-version:1.1\n";
+ String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
+ String frame = "CONNECT\n" + accept + host + "\n";
+ Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
+
+ assertEquals(0, message.getPayload().length);
+
+ MessageHeaders messageHeaders = message.getHeaders();
+ StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
+ assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
+ assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt"));
+
+ String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
+
+ assertEquals("CONNECT\n", convertedBack.substring(0,8));
+ assertTrue(convertedBack.contains(accept));
+ assertTrue(convertedBack.contains(host));
+ }
+
+ @Test
+ public void connectCR12() throws Exception {
+
+ String accept = "accept-version:1.2\n";
+ String host = "host:github.org\n";
+ String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
+ Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
+
+ assertEquals(0, message.getPayload().length);
+
+ MessageHeaders messageHeaders = message.getHeaders();
+ StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
+ assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion());
+ assertEquals("github.org", stompHeaders.getRawHeaders().get("host"));
+
+ String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
+
+ assertEquals("CONNECT\n", convertedBack.substring(0,8));
+ assertTrue(convertedBack.contains(accept));
+ assertTrue(convertedBack.contains(host));
+ }
+
+ @Test
+ public void connectWithEscapesAndCR12() throws Exception {
+
+ String accept = "accept-version:1.1\n";
+ String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
+ String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
+ Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
+
+ assertEquals(0, message.getPayload().length);
+
+ MessageHeaders messageHeaders = message.getHeaders();
+ StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
+ assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
+ assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt"));
+
+ String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
+
+ assertEquals("CONNECT\n", convertedBack.substring(0,8));
+ assertTrue(convertedBack.contains(accept));
+ assertTrue(convertedBack.contains(host));
+ }
+
+}