Browse Source

Refactor STOMP and PubSub header message support

pull/286/merge
Rossen Stoyanchev 13 years ago
parent
commit
ad41f095a1
  1. 187
      spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java
  2. 10
      spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
  3. 10
      spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
  4. 2
      spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java
  5. 6
      spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java
  6. 6
      spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java
  7. 284
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java
  8. 6
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java
  9. 21
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java
  10. 19
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java
  11. 34
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java
  12. 24
      spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java

187
spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java

@ -22,143 +22,220 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
/** /**
* A base class for working with message headers in Web, messaging protocols that support
* the publish-subscribe message pattern. Provides uniform access to specific values
* common across protocols such as a destination, message type (publish,
* subscribe/unsubscribe), session id, and others.
* <p>
* This class can be used to prepare headers for a new pub-sub message, or to access
* and/or modify headers of an existing message.
* <p>
* Use one of the static factory method in this class, then call getters and setters, and
* at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
*/ */
public class PubSubHeaders { public class PubSubHeaders {
protected Log logger = LogFactory.getLog(getClass());
private static final String DESTINATIONS = "destinations"; private static final String DESTINATIONS = "destinations";
private static final String CONTENT_TYPE = "contentType"; private static final String CONTENT_TYPE = "contentType";
private static final String MESSAGE_TYPE = "messageType"; private static final String MESSAGE_TYPE = "messageType";
private static final String SUBSCRIPTION_ID = "subscriptionId";
private static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType"; private static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
private static final String SESSION_ID = "sessionId"; private static final String SESSION_ID = "sessionId";
private static final String RAW_HEADERS = "rawHeaders"; private static final String SUBSCRIPTION_ID = "subscriptionId";
private static final String EXTERNAL_SOURCE_HEADERS = "extSourceHeaders";
private final Map<String, Object> messageHeaders; private static final Map<String, List<String>> emptyMultiValueMap =
Collections.unmodifiableMap(new LinkedMultiValueMap<String, String>(0));
// wrapped read-only message headers
private final MessageHeaders originalHeaders;
// header updates
private final Map<String, Object> headers = new HashMap<String, Object>(4);
// saved headers from a message from a remote source
private final Map<String, List<String>> externalSourceHeaders;
private final Map<String, String> rawHeaders;
/** /**
* Constructor for building new headers. * A constructor for creating new message headers.
* * This constructor is protected. See factory methods in this and sub-classes.
* @param messageType the message type
* @param protocolMessageType the protocol-specific message type or command
*/ */
public PubSubHeaders(MessageType messageType, Object protocolMessageType) { protected PubSubHeaders(MessageType messageType, Object protocolMessageType,
Map<String, List<String>> externalSourceHeaders) {
this.originalHeaders = null;
Assert.notNull(messageType, "messageType is required");
this.headers.put(MESSAGE_TYPE, messageType);
this.messageHeaders = new HashMap<String, Object>();
this.messageHeaders.put(MESSAGE_TYPE, messageType);
if (protocolMessageType != null) { if (protocolMessageType != null) {
this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType); this.headers.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
} }
this.rawHeaders = new HashMap<String, String>(); if (externalSourceHeaders == null) {
this.messageHeaders.put(RAW_HEADERS, this.rawHeaders); this.externalSourceHeaders = emptyMultiValueMap;
} }
else {
public PubSubHeaders() { this.externalSourceHeaders = Collections.unmodifiableMap(externalSourceHeaders); // TODO: list values must also be read-only
this(MessageType.MESSAGE, null); this.headers.put(EXTERNAL_SOURCE_HEADERS, this.externalSourceHeaders);
}
} }
/** /**
* Constructor for access to existing {@link MessageHeaders}. * A constructor for accessing and modifying existing message headers. This
* * constructor is protected. See factory methods in this and sub-classes.
* @param messageHeaders
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public PubSubHeaders(MessageHeaders messageHeaders, boolean readOnly) { protected PubSubHeaders(MessageHeaders originalHeaders) {
this.originalHeaders = originalHeaders;
this.externalSourceHeaders = (originalHeaders.get(EXTERNAL_SOURCE_HEADERS) != null) ?
(Map<String, List<String>>) originalHeaders.get(EXTERNAL_SOURCE_HEADERS) : emptyMultiValueMap;
}
this.messageHeaders = readOnly ? messageHeaders : new HashMap<String, Object>(messageHeaders);
this.rawHeaders = this.messageHeaders.containsKey(RAW_HEADERS) ?
(Map<String, String>) messageHeaders.get(RAW_HEADERS) : Collections.<String, String>emptyMap();
if (this.messageHeaders.get(MESSAGE_TYPE) == null) { /**
this.messageHeaders.put(MESSAGE_TYPE, MessageType.MESSAGE); * Create {@link PubSubHeaders} for a new {@link Message}.
} */
public static PubSubHeaders create() {
return new PubSubHeaders(MessageType.MESSAGE, null, null);
}
/**
* Create {@link PubSubHeaders} from existing message headers.
*/
public static PubSubHeaders fromMessageHeaders(MessageHeaders originalHeaders) {
return new PubSubHeaders(originalHeaders);
} }
public Map<String, Object> getMessageHeaders() { /**
return this.messageHeaders; * Return the original, wrapped headers (i.e. unmodified) or a new Map including any
* updates made via setters.
*/
public Map<String, Object> toMessageHeaders() {
if (!isModified()) {
return this.originalHeaders;
}
Map<String, Object> result = new HashMap<String, Object>();
if (this.originalHeaders != null) {
result.putAll(this.originalHeaders);
}
result.putAll(this.headers);
return result;
} }
public Map<String, String> getRawHeaders() { public boolean isModified() {
return this.rawHeaders; return ((this.originalHeaders == null) || !this.headers.isEmpty());
} }
public MessageType getMessageType() { public MessageType getMessageType() {
return (MessageType) this.messageHeaders.get(MESSAGE_TYPE); return (MessageType) getHeaderValue(MESSAGE_TYPE);
}
private Object getHeaderValue(String headerName) {
if (this.headers.get(headerName) != null) {
return this.headers.get(headerName);
}
else if (this.originalHeaders.get(headerName) != null) {
return this.originalHeaders.get(headerName);
}
return null;
} }
public void setProtocolMessageType(Object protocolMessageType) { protected void setProtocolMessageType(Object protocolMessageType) {
this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType); this.headers.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
} }
public Object getProtocolMessageType() { protected Object getProtocolMessageType() {
return this.messageHeaders.get(PROTOCOL_MESSAGE_TYPE); return getHeaderValue(PROTOCOL_MESSAGE_TYPE);
} }
public void setDestination(String destination) { public void setDestination(String destination) {
this.messageHeaders.put(DESTINATIONS, Arrays.asList(destination)); Assert.notNull(destination, "destination is required");
this.headers.put(DESTINATIONS, Arrays.asList(destination));
} }
@SuppressWarnings("unchecked")
public String getDestination() { public String getDestination() {
@SuppressWarnings("unchecked") List<String> destinations = (List<String>) getHeaderValue(DESTINATIONS);
List<String> destination = (List<String>) messageHeaders.get(DESTINATIONS); return CollectionUtils.isEmpty(destinations) ? null : destinations.get(0);
return CollectionUtils.isEmpty(destination) ? null : destination.get(0);
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public List<String> getDestinations() { public List<String> getDestinations() {
return (List<String>) messageHeaders.get(DESTINATIONS); List<String> destinations = (List<String>) getHeaderValue(DESTINATIONS);
return CollectionUtils.isEmpty(destinations) ? null : destinations;
} }
public void setDestinations(List<String> destinations) { public void setDestinations(List<String> destinations) {
if (destinations != null) { Assert.notNull(destinations, "destinations are required");
this.messageHeaders.put(DESTINATIONS, destinations); this.headers.put(DESTINATIONS, destinations);
}
} }
public MediaType getContentType() { public MediaType getContentType() {
return (MediaType) this.messageHeaders.get(CONTENT_TYPE); return (MediaType) getHeaderValue(CONTENT_TYPE);
} }
public void setContentType(MediaType mediaType) { public void setContentType(MediaType contentType) {
if (mediaType != null) { Assert.notNull(contentType, "contentType is required");
this.messageHeaders.put(CONTENT_TYPE, mediaType); this.headers.put(CONTENT_TYPE, contentType);
}
} }
public String getSubscriptionId() { public String getSubscriptionId() {
return (String) this.messageHeaders.get(SUBSCRIPTION_ID); return (String) getHeaderValue(SUBSCRIPTION_ID);
} }
public void setSubscriptionId(String subscriptionId) { public void setSubscriptionId(String subscriptionId) {
this.messageHeaders.put(SUBSCRIPTION_ID, subscriptionId); this.headers.put(SUBSCRIPTION_ID, subscriptionId);
} }
public String getSessionId() { public String getSessionId() {
return (String) this.messageHeaders.get(SESSION_ID); return (String) getHeaderValue(SESSION_ID);
} }
public void setSessionId(String sessionId) { public void setSessionId(String sessionId) {
this.messageHeaders.put(SESSION_ID, sessionId); this.headers.put(SESSION_ID, sessionId);
}
/**
* Return a read-only map of headers originating from a message received by the
* application from an external source (e.g. from a remote WebSocket endpoint). The
* header names and values are exactly as they were, and are protocol specific but may
* also be custom application headers if the protocol allows that.
*/
public Map<String, List<String>> getExternalSourceHeaders() {
return this.externalSourceHeaders;
}
@Override
public String toString() {
return "PubSubHeaders [originalHeaders=" + this.originalHeaders + ", headers="
+ this.headers + ", externalSourceHeaders=" + this.externalSourceHeaders + "]";
} }
} }

10
spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java

@ -72,10 +72,10 @@ public abstract class AbstractMessageService {
} }
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Processing notification: " + message); logger.trace("Processing message id=" + message.getHeaders().getId());
} }
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true); PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
MessageType messageType = headers.getMessageType(); MessageType messageType = headers.getMessageType();
if (messageType == null || messageType.equals(MessageType.OTHER)) { if (messageType == null || messageType.equals(MessageType.OTHER)) {
processOther(message); processOther(message);
@ -129,7 +129,7 @@ public abstract class AbstractMessageService {
} }
private boolean isAllowedDestination(Message<?> message) { private boolean isAllowedDestination(Message<?> message) {
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true); PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String destination = headers.getDestination(); String destination = headers.getDestination();
if (destination == null) { if (destination == null) {
return true; return true;
@ -138,7 +138,7 @@ public abstract class AbstractMessageService {
for (String pattern : this.disallowedDestinations) { for (String pattern : this.disallowedDestinations) {
if (this.pathMatcher.match(pattern, destination)) { if (this.pathMatcher.match(pattern, destination)) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Skip notification: " + message); logger.trace("Skip notification message id=" + message.getHeaders().getId());
} }
return false; return false;
} }
@ -151,7 +151,7 @@ public abstract class AbstractMessageService {
} }
} }
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Skip notification: " + message); logger.trace("Skip notification message id=" + message.getHeaders().getId());
} }
return false; return false;
} }

10
spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java

@ -62,7 +62,7 @@ public class PubSubMessageService extends AbstractMessageService {
try { try {
// Convert to byte[] payload before the fan-out // Convert to byte[] payload before the fan-out
PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true); PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType()); byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType());
message = new GenericMessage<byte[]>(payload, message.getHeaders()); message = new GenericMessage<byte[]>(payload, message.getHeaders());
@ -82,19 +82,19 @@ public class PubSubMessageService extends AbstractMessageService {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Subscribe " + message); logger.debug("Subscribe " + message);
} }
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true); PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
final String subscriptionId = headers.getSubscriptionId(); final String subscriptionId = headers.getSubscriptionId();
EventRegistration registration = getEventBus().registerConsumer(getPublishKey(headers.getDestination()), EventRegistration registration = getEventBus().registerConsumer(getPublishKey(headers.getDestination()),
new EventConsumer<Message<?>>() { new EventConsumer<Message<?>>() {
@Override @Override
public void accept(Message<?> message) { public void accept(Message<?> message) {
PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true); PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaders outHeaders = new PubSubHeaders(); PubSubHeaders outHeaders = PubSubHeaders.create();
outHeaders.setDestinations(inHeaders.getDestinations()); outHeaders.setDestinations(inHeaders.getDestinations());
outHeaders.setContentType(inHeaders.getContentType()); outHeaders.setContentType(inHeaders.getContentType());
outHeaders.setSubscriptionId(subscriptionId); outHeaders.setSubscriptionId(subscriptionId);
Object payload = message.getPayload(); Object payload = message.getPayload();
message = new GenericMessage<Object>(payload, outHeaders.getMessageHeaders()); message = new GenericMessage<Object>(payload, outHeaders.toMessageHeaders());
getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message); getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message);
} }
}); });

2
spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java

@ -167,7 +167,7 @@ public class AnnotationMessageService extends AbstractMessageService implements
private void handleMessage(final Message<?> message, Map<MappingInfo, HandlerMethod> handlerMethods) { private void handleMessage(final Message<?> message, Map<MappingInfo, HandlerMethod> handlerMethods) {
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true); PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String destination = headers.getDestination(); String destination = headers.getDestination();
HandlerMethod match = getHandlerMethod(destination, handlerMethods); HandlerMethod match = getHandlerMethod(destination, handlerMethods);

6
spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java

@ -49,16 +49,16 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
@Override @Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception { public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
final String sessionId = new PubSubHeaders(message.getHeaders(), true).getSessionId(); final String sessionId = PubSubHeaders.fromMessageHeaders(message.getHeaders()).getSessionId();
return new MessageChannel() { return new MessageChannel() {
@Override @Override
public boolean send(Message<?> message) { public boolean send(Message<?> message) {
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), false); PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
headers.setSessionId(sessionId); headers.setSessionId(sessionId);
message = new GenericMessage<Object>(message.getPayload(), headers.getMessageHeaders()); message = new GenericMessage<Object>(message.getPayload(), headers.toMessageHeaders());
eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message); eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message);

6
spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java

@ -66,15 +66,15 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
return; return;
} }
PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true); PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String sessionId = inHeaders.getSessionId(); String sessionId = inHeaders.getSessionId();
String subscriptionId = inHeaders.getSubscriptionId(); String subscriptionId = inHeaders.getSubscriptionId();
Assert.notNull(subscriptionId, "No subscription id: " + message); Assert.notNull(subscriptionId, "No subscription id: " + message);
PubSubHeaders outHeaders = new PubSubHeaders(returnMessage.getHeaders(), false); PubSubHeaders outHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders());
outHeaders.setSessionId(sessionId); outHeaders.setSessionId(sessionId);
outHeaders.setSubscriptionId(subscriptionId); outHeaders.setSubscriptionId(subscriptionId);
returnMessage = new GenericMessage<Object>(returnMessage.getPayload(), outHeaders.getMessageHeaders()); returnMessage = new GenericMessage<Object>(returnMessage.getPayload(), outHeaders.toMessageHeaders());
this.eventBus.send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, returnMessage); this.eventBus.send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, returnMessage);
} }

284
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java

@ -17,11 +17,17 @@
package org.springframework.web.messaging.stomp; package org.springframework.web.messaging.stomp;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessageHeaders;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.messaging.PubSubHeaders; import org.springframework.web.messaging.PubSubHeaders;
@ -29,7 +35,12 @@ import reactor.util.Assert;
/** /**
* STOMP adapter for {@link MessageHeaders}. * Can be used to prepare headers for a new STOMP message, or to access and/or modify
* STOMP-specific headers of an existing message.
* <p>
* Use one of the static factory method in this class, then call getters and setters, and
* at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers
* or call {@link #toStompMessageHeaders()} to obtain only the STOMP-specific headers.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
@ -54,6 +65,8 @@ public class StompHeaders extends PubSubHeaders {
private static final String ACK = "ack"; private static final String ACK = "ack";
private static final String NACK = "nack";
private static final String DESTINATION = "destination"; private static final String DESTINATION = "destination";
private static final String CONTENT_TYPE = "content-type"; private static final String CONTENT_TYPE = "content-type";
@ -63,30 +76,133 @@ public class StompHeaders extends PubSubHeaders {
private static final String HEARTBEAT = "heart-beat"; private static final String HEARTBEAT = "heart-beat";
private static final String STOMP_HEADERS = "stompHeaders";
private final Map<String, String> headers;
/**
* A constructor for creating new STOMP message headers.
* This constructor is private. See factory methods in this sub-classes.
*/
private StompHeaders(StompCommand command, Map<String, List<String>> externalSourceHeaders) {
super(command.getMessageType(), command, externalSourceHeaders);
this.headers = new HashMap<String, String>(4);
updateMessageHeaders();
}
private void updateMessageHeaders() {
if (getExternalSourceHeaders().isEmpty()) {
return;
}
String destination = getHeaderValue(DESTINATION);
if (destination != null) {
super.setDestination(destination);
}
String contentType = getHeaderValue(CONTENT_TYPE);
if (contentType != null) {
super.setContentType(MediaType.parseMediaType(contentType));
}
if (StompCommand.SUBSCRIBE.equals(getStompCommand())) {
if (getHeaderValue(ID) != null) {
super.setSubscriptionId(getHeaderValue(ID));
}
}
}
/**
* A constructor for accessing and modifying existing message headers. This
* constructor is protected. See factory methods in this class.
*/
@SuppressWarnings("unchecked")
private StompHeaders(MessageHeaders messageHeaders) {
super(messageHeaders);
this.headers = (messageHeaders.get(STOMP_HEADERS) != null) ?
(Map<String, String>) messageHeaders.get(STOMP_HEADERS) : new HashMap<String, String>(4);
}
/**
* Create {@link StompHeaders} for a new {@link Message}.
*/
public static StompHeaders create(StompCommand command) {
return new StompHeaders(command, null);
}
/** /**
* Constructor for building new headers. * Create {@link StompHeaders} from the headers of an existing {@link Message}.
*
* @param command the STOMP command
*/ */
public StompHeaders(StompCommand command) { public static StompHeaders fromMessageHeaders(MessageHeaders messageHeaders) {
super(command.getMessageType(), command); return new StompHeaders(messageHeaders);
} }
/** /**
* Constructor for access to existing {@link MessageHeaders}. * Create {@link StompHeaders} from parsed STOP frame content.
*
* @param messageHeaders the existing message headers
* @param readOnly whether the resulting instance will be used for read-only access,
* if {@code true}, then set methods will throw exceptions; if {@code false}
* they will work.
*/ */
public StompHeaders(MessageHeaders messageHeaders, boolean readOnly) { public static StompHeaders fromParsedFrame(StompCommand command, Map<String, List<String>> headers) {
super(messageHeaders, readOnly); return new StompHeaders(command, headers);
} }
/**
* Return the original, wrapped headers (i.e. unmodified) or a new Map including any
* updates made via setters.
*/
@Override @Override
public StompCommand getProtocolMessageType() { public Map<String, Object> toMessageHeaders() {
return (StompCommand) super.getProtocolMessageType(); Map<String, Object> result = super.toMessageHeaders();
if (isModified()) {
result.put(STOMP_HEADERS, this.headers);
}
return result;
}
@Override
public boolean isModified() {
return (super.isModified() || !this.headers.isEmpty());
}
/**
* Return STOMP headers and any custom headers that may have been sent by
* a remote endpoint, if this message originated from outside.
*/
public Map<String, List<String>> toStompMessageHeaders() {
MultiValueMap<String, String> result = new LinkedMultiValueMap<String, String>();
result.putAll(getExternalSourceHeaders());
result.setAll(this.headers);
String destination = super.getDestination();
if (destination != null) {
result.set(DESTINATION, destination);
}
MediaType contentType = getContentType();
if (contentType != null) {
result.set(CONTENT_TYPE, contentType.toString());
}
if (StompCommand.MESSAGE.equals(getStompCommand())) {
String subscriptionId = getSubscriptionId();
if (subscriptionId != null) {
result.set(SUBSCRIPTION, subscriptionId);
}
else {
logger.warn("STOMP MESSAGE frame should have a subscription: " + this.toString());
}
if ((getMessageId() == null)) {
this.headers.put(MESSAGE_ID, toMessageHeaders().get(ID).toString());
}
}
return result;
}
public void setStompCommandIfNotSet(StompCommand command) {
if (getStompCommand() == null) {
setProtocolMessageType(command);
}
} }
public StompCommand getStompCommand() { public StompCommand getStompCommand() {
@ -94,32 +210,42 @@ public class StompHeaders extends PubSubHeaders {
} }
public Set<String> getAcceptVersion() { public Set<String> getAcceptVersion() {
String rawValue = getRawHeaders().get(ACCEPT_VERSION); String rawValue = getHeaderValue(ACCEPT_VERSION);
return (rawValue != null) ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.<String>emptySet(); return (rawValue != null) ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.<String>emptySet();
} }
private String getHeaderValue(String headerName) {
List<String> values = getExternalSourceHeaders().get(headerName);
return !CollectionUtils.isEmpty(values) ? values.get(0) : this.headers.get(headerName);
}
public void setAcceptVersion(String acceptVersion) { public void setAcceptVersion(String acceptVersion) {
getRawHeaders().put(ACCEPT_VERSION, acceptVersion); this.headers.put(ACCEPT_VERSION, acceptVersion);
}
public void setHost(String host) {
this.headers.put(HOST, host);
}
public String getHost() {
return getHeaderValue(HOST);
} }
@Override @Override
public void setDestination(String destination) { public void setDestination(String destination) {
if (destination != null) { super.setDestination(destination);
super.setDestination(destination); this.headers.put(DESTINATION, destination);
getRawHeaders().put(DESTINATION, destination);
}
} }
@Override @Override
public void setDestinations(List<String> destinations) { public void setDestinations(List<String> destinations) {
if (destinations != null) { Assert.isTrue((destinations != null) && (destinations.size() == 1), "STOMP allows one destination per message");
super.setDestinations(destinations); super.setDestinations(destinations);
getRawHeaders().put(DESTINATION, destinations.get(0)); this.headers.put(DESTINATION, destinations.get(0));
}
} }
public long[] getHeartbeat() { public long[] getHeartbeat() {
String rawValue = getRawHeaders().get(HEARTBEAT); String rawValue = getHeaderValue(HEARTBEAT);
if (!StringUtils.hasText(rawValue)) { if (!StringUtils.hasText(rawValue)) {
return null; return null;
} }
@ -131,102 +257,82 @@ public class StompHeaders extends PubSubHeaders {
public void setContentType(MediaType mediaType) { public void setContentType(MediaType mediaType) {
if (mediaType != null) { if (mediaType != null) {
super.setContentType(mediaType); super.setContentType(mediaType);
getRawHeaders().put(CONTENT_TYPE, mediaType.toString()); this.headers.put(CONTENT_TYPE, mediaType.toString());
} }
} }
public MediaType getContentType() {
String value = getHeaderValue(CONTENT_TYPE);
return (value != null) ? MediaType.parseMediaType(value) : null;
}
public Integer getContentLength() { public Integer getContentLength() {
String contentLength = getRawHeaders().get(CONTENT_LENGTH); String contentLength = getHeaderValue(CONTENT_LENGTH);
return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null; return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null;
} }
public void setContentLength(int contentLength) { public void setContentLength(int contentLength) {
getRawHeaders().put(CONTENT_LENGTH, String.valueOf(contentLength)); this.headers.put(CONTENT_LENGTH, String.valueOf(contentLength));
} }
@Override public void setHeartbeat(long cx, long cy) {
public String getSubscriptionId() { this.headers.put(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
return StompCommand.SUBSCRIBE.equals(getStompCommand()) ? getRawHeaders().get(ID) : null;
} }
@Override public void setAck(String ack) {
public void setSubscriptionId(String subscriptionId) { this.headers.put(ACK, ack);
Assert.isTrue(StompCommand.MESSAGE.equals(getStompCommand()),
"\"subscription\" can only be set on a STOMP MESSAGE frame");
super.setSubscriptionId(subscriptionId);
getRawHeaders().put(SUBSCRIPTION, subscriptionId);
} }
public void setHeartbeat(long cx, long cy) { public String getAck() {
getRawHeaders().put(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy})); return getHeaderValue(ACK);
}
public void setNack(String nack) {
this.headers.put(NACK, nack);
}
public String getNack() {
return getHeaderValue(NACK);
}
public void setReceiptId(String receiptId) {
this.headers.put(RECEIPT_ID, receiptId);
}
public String getReceiptId() {
return getHeaderValue(RECEIPT_ID);
} }
public String getMessage() { public String getMessage() {
return getRawHeaders().get(MESSAGE); return getHeaderValue(MESSAGE);
} }
public void setMessage(String content) { public void setMessage(String content) {
getRawHeaders().put(MESSAGE, content); this.headers.put(MESSAGE, content);
} }
public String getMessageId() { public String getMessageId() {
return getRawHeaders().get(MESSAGE_ID); return getHeaderValue(MESSAGE_ID);
} }
public void setMessageId(String id) { public void setMessageId(String id) {
getRawHeaders().put(MESSAGE_ID, id); this.headers.put(MESSAGE_ID, id);
} }
public String getVersion() { public String getVersion() {
return getRawHeaders().get(VERSION); return getHeaderValue(VERSION);
} }
public void setVersion(String version) { public void setVersion(String version) {
getRawHeaders().put(VERSION, version); this.headers.put(VERSION, version);
}
/**
* Update generic message headers from raw headers. This method only needs to be
* invoked when raw headers are added via {@link #getRawHeaders()}.
*/
public void updateMessageHeaders() {
String destination = getRawHeaders().get(DESTINATION);
if (destination != null) {
setDestination(destination);
}
String contentType = getRawHeaders().get(CONTENT_TYPE);
if (contentType != null) {
setContentType(MediaType.parseMediaType(contentType));
}
if (StompCommand.SUBSCRIBE.equals(getStompCommand())) {
if (getRawHeaders().get(ID) != null) {
super.setSubscriptionId(getRawHeaders().get(ID));
}
}
} }
/** @Override
* Update raw headers from generic message headers. This method only needs to be public String toString() {
* invoked if creating {@link StompHeaders} from {@link MessageHeaders} that never return "StompHeaders [" + "messageType=" + getMessageType() + ", protocolMessageType="
* contained raw headers. + getProtocolMessageType() + ", destination=" + getDestination()
*/ + ", subscriptionId=" + getSubscriptionId() + ", sessionId=" + getSessionId()
public void updateRawHeaders() { + ", externalSourceHeaders=" + getExternalSourceHeaders() + ", headers=" + this.headers + "]";
String destination = getDestination();
if (destination != null) {
getRawHeaders().put(DESTINATION, destination);
}
MediaType contentType = getContentType();
if (contentType != null) {
getRawHeaders().put(CONTENT_TYPE, contentType.toString());
}
String subscriptionId = getSubscriptionId();
if (subscriptionId != null) {
getRawHeaders().put(SUBSCRIPTION, subscriptionId);
}
if (StompCommand.MESSAGE.equals(getStompCommand()) && (getMessageId() == null)) {
getRawHeaders().put(MESSAGE_ID, getMessageHeaders().get(ID).toString());
}
} }
} }

6
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java

@ -21,8 +21,8 @@ import java.util.concurrent.ConcurrentHashMap;
import org.springframework.messaging.GenericMessage; import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.StompCommand; import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.support.StompMessageConverter; import org.springframework.web.messaging.stomp.support.StompMessageConverter;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
@ -76,10 +76,10 @@ public abstract class AbstractStompWebSocketHandler extends TextWebSocketHandler
protected void sendErrorMessage(WebSocketSession session, Throwable error) { protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR); StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR);
stompHeaders.setMessage(error.getMessage()); stompHeaders.setMessage(error.getMessage());
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.getMessageHeaders()); Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.toMessageHeaders());
byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage); byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage);
try { try {

21
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java

@ -63,10 +63,8 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
@Override @Override
public void accept(Message<?> message) { public void accept(Message<?> message) {
StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false); StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
if (stompHeaders.getProtocolMessageType() == null) { stompHeaders.setStompCommandIfNotSet(StompCommand.MESSAGE);
stompHeaders.setProtocolMessageType(StompCommand.MESSAGE);
}
if (StompCommand.CONNECTED.equals(stompHeaders.getStompCommand())) { if (StompCommand.CONNECTED.equals(stompHeaders.getStompCommand())) {
// Ignore for now since we already sent it // Ignore for now since we already sent it
@ -74,6 +72,9 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
} }
String sessionId = stompHeaders.getSessionId(); String sessionId = stompHeaders.getSessionId();
if (sessionId == null) {
logger.error("Cannot send message without a session id: " + message);
}
WebSocketSession session = getWebSocketSession(sessionId); WebSocketSession session = getWebSocketSession(sessionId);
byte[] payload; byte[] payload;
@ -87,7 +88,7 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
} }
try { try {
Map<String, Object> messageHeaders = stompHeaders.getMessageHeaders(); Map<String, Object> messageHeaders = stompHeaders.toMessageHeaders();
Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, messageHeaders); Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, messageHeaders);
byte[] bytes = getStompMessageConverter().fromMessage(byteMessage); byte[] bytes = getStompMessageConverter().fromMessage(byteMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
@ -116,11 +117,11 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
public void handleStompMessage(final WebSocketSession session, Message<byte[]> message) { public void handleStompMessage(final WebSocketSession session, Message<byte[]> message) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Processing: " + message); logger.trace("Processing STOMP message: " + message);
} }
try { try {
StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), true); StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
MessageType messageType = stompHeaders.getMessageType(); MessageType messageType = stompHeaders.getMessageType();
if (MessageType.CONNECT.equals(messageType)) { if (MessageType.CONNECT.equals(messageType)) {
handleConnect(session, message); handleConnect(session, message);
@ -147,8 +148,8 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
protected void handleConnect(final WebSocketSession session, Message<byte[]> message) throws IOException { protected void handleConnect(final WebSocketSession session, Message<byte[]> message) throws IOException {
StompHeaders connectStompHeaders = new StompHeaders(message.getHeaders(), true); StompHeaders connectStompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaders connectedStompHeaders = new StompHeaders(StompCommand.CONNECTED); StompHeaders connectedStompHeaders = StompHeaders.create(StompCommand.CONNECTED);
Set<String> acceptVersions = connectStompHeaders.getAcceptVersion(); Set<String> acceptVersions = connectStompHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) { if (acceptVersions.contains("1.2")) {
@ -167,7 +168,7 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
// TODO: security // TODO: security
Message<byte[]> connectedMessage = new GenericMessage<byte[]>(new byte[0], connectedStompHeaders.getMessageHeaders()); Message<byte[]> connectedMessage = new GenericMessage<byte[]>(new byte[0], connectedStompHeaders.toMessageHeaders());
byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage); byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
} }

19
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java

@ -93,21 +93,20 @@ public class RelayStompService extends AbstractMessageService {
private void forwardMessage(Message<?> message, StompCommand command) { private void forwardMessage(Message<?> message, StompCommand command) {
StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false); StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
String sessionId = stompHeaders.getSessionId(); String sessionId = stompHeaders.getSessionId();
RelaySession session = RelayStompService.this.relaySessions.get(sessionId); RelaySession session = RelayStompService.this.relaySessions.get(sessionId);
Assert.notNull(session, "RelaySession not found"); Assert.notNull(session, "RelaySession not found");
try { try {
if (stompHeaders.getProtocolMessageType() == null) { stompHeaders.setStompCommandIfNotSet(StompCommand.SEND);
stompHeaders.setProtocolMessageType(StompCommand.SEND);
}
MediaType contentType = stompHeaders.getContentType(); MediaType contentType = stompHeaders.getContentType();
byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType); byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType);
Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, stompHeaders.getMessageHeaders()); Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, stompHeaders.toMessageHeaders());
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Forwarding: " + byteMessage); logger.trace("Forwarding STOMP " + stompHeaders.getStompCommand() + " message");
} }
byte[] bytes = this.stompMessageConverter.fromMessage(byteMessage); byte[] bytes = this.stompMessageConverter.fromMessage(byteMessage);
@ -115,7 +114,7 @@ public class RelayStompService extends AbstractMessageService {
session.getOutputStream().flush(); session.getOutputStream().flush();
} }
catch (Exception ex) { catch (Exception ex) {
logger.error("Couldn't forward message", ex); logger.error("Couldn't forward message " + message, ex);
clearRelaySession(sessionId); clearRelaySession(sessionId);
} }
} }
@ -233,13 +232,15 @@ public class RelayStompService extends AbstractMessageService {
catch (IOException e) { catch (IOException e) {
logger.error("Socket error: " + e.getMessage()); logger.error("Socket error: " + e.getMessage());
clearRelaySession(sessionId); clearRelaySession(sessionId);
sendErrorMessage("Lost connection");
} }
} }
private void sendErrorMessage(String message) { private void sendErrorMessage(String message) {
StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR); StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR);
stompHeaders.setMessage(message); stompHeaders.setMessage(message);
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.getMessageHeaders()); stompHeaders.setSessionId(this.sessionId);
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.toMessageHeaders());
getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, errorMessage); getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, errorMessage);
} }
} }

34
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java

@ -18,6 +18,7 @@ package org.springframework.web.messaging.stomp.support;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
@ -25,6 +26,8 @@ import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.messaging.stomp.StompCommand; import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompConversionException; import org.springframework.web.messaging.stomp.StompConversionException;
import org.springframework.web.messaging.stomp.StompHeaders; import org.springframework.web.messaging.stomp.StompHeaders;
@ -80,15 +83,13 @@ public class StompMessageConverter {
StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim()); StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim());
Assert.notNull(command, "No command found"); Assert.notNull(command, "No command found");
StompHeaders stompHeaders = new StompHeaders(command); MultiValueMap<String, String> headers = new LinkedMultiValueMap<String, String>();
stompHeaders.setSessionId(sessionId);
while (parser.hasNext()) { while (parser.hasNext()) {
String header = parser.nextToken(COLON); String header = parser.nextToken(COLON);
if (header != null) { if (header != null) {
if (parser.hasNext()) { if (parser.hasNext()) {
String value = parser.nextToken(LF); String value = parser.nextToken(LF);
stompHeaders.getRawHeaders().put(header, value); headers.add(header, value);
} }
else { else {
throw new StompConversionException("Parse exception for " + headerContent); throw new StompConversionException("Parse exception for " + headerContent);
@ -96,12 +97,13 @@ public class StompMessageConverter {
} }
} }
StompHeaders stompHeaders = StompHeaders.fromParsedFrame(command, headers);
stompHeaders.setSessionId(sessionId);
byte[] payload = new byte[totalLength - payloadIndex]; byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex); System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
stompHeaders.updateMessageHeaders(); return createMessage(command, stompHeaders.toMessageHeaders(), payload);
return createMessage(command, stompHeaders.getMessageHeaders(), payload);
} }
private int findIndexOfPayload(byte[] bytes) { private int findIndexOfPayload(byte[] bytes) {
@ -138,20 +140,20 @@ public class StompMessageConverter {
public byte[] fromMessage(Message<byte[]> message) { public byte[] fromMessage(Message<byte[]> message) {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
MessageHeaders messageHeaders = message.getHeaders(); MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, false); StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
stompHeaders.updateRawHeaders();
try { try {
out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8")); out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8"));
out.write(LF); out.write(LF);
for (Entry<String, String> entry : stompHeaders.getRawHeaders().entrySet()) { for (Entry<String, List<String>> entry : stompHeaders.toStompMessageHeaders().entrySet()) {
String key = entry.getKey(); String key = entry.getKey();
key = replaceAllOutbound(key); key = replaceAllOutbound(key);
String value = entry.getValue(); for (String value : entry.getValue()) {
out.write(key.getBytes("UTF-8")); out.write(key.getBytes("UTF-8"));
out.write(COLON); out.write(COLON);
value = replaceAllOutbound(value); value = replaceAllOutbound(value);
out.write(value.getBytes("UTF-8")); out.write(value.getBytes("UTF-8"));
out.write(LF); out.write(LF);
}
} }
out.write(LF); out.write(LF);
out.write(message.getPayload()); out.write(message.getPayload());

24
spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java

@ -22,8 +22,8 @@ import org.junit.Test;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessageHeaders;
import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.StompCommand; import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompHeaders;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -52,15 +52,17 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length); assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders(); MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
assertEquals(6, stompHeaders.getMessageHeaders().size()); assertEquals(7, stompHeaders.toMessageHeaders().size());
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost());
assertEquals(MessageType.CONNECT, stompHeaders.getMessageType()); assertEquals(MessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand()); assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand());
assertEquals("session-123", stompHeaders.getSessionId()); assertEquals("session-123", stompHeaders.getSessionId());
assertNotNull(messageHeaders.get(MessageHeaders.ID)); assertNotNull(messageHeaders.get(MessageHeaders.ID));
assertNotNull(messageHeaders.get(MessageHeaders.TIMESTAMP)); assertNotNull(messageHeaders.get(MessageHeaders.TIMESTAMP));
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getRawHeaders().get("host"));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
@ -80,9 +82,9 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length); assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders(); MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt")); assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getExternalSourceHeaders().get("ho:\ns\rt").get(0));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
@ -102,9 +104,9 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length); assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders(); MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion()); assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getRawHeaders().get("host")); assertEquals("github.org", stompHeaders.getHost());
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
@ -124,9 +126,9 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length); assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders(); MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt")); assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getExternalSourceHeaders().get("ho:\ns\rt").get(0));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");

Loading…
Cancel
Save