Browse Source

Refactor STOMP and PubSub header message support

pull/286/merge
Rossen Stoyanchev 13 years ago
parent
commit
ad41f095a1
  1. 187
      spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java
  2. 10
      spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
  3. 10
      spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
  4. 2
      spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java
  5. 6
      spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java
  6. 6
      spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java
  7. 284
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java
  8. 6
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java
  9. 21
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java
  10. 19
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java
  11. 34
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java
  12. 24
      spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java

187
spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java

@ -22,143 +22,220 @@ import java.util.HashMap; @@ -22,143 +22,220 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
/**
* A base class for working with message headers in Web, messaging protocols that support
* the publish-subscribe message pattern. Provides uniform access to specific values
* common across protocols such as a destination, message type (publish,
* subscribe/unsubscribe), session id, and others.
* <p>
* This class can be used to prepare headers for a new pub-sub message, or to access
* and/or modify headers of an existing message.
* <p>
* Use one of the static factory method in this class, then call getters and setters, and
* at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class PubSubHeaders {
protected Log logger = LogFactory.getLog(getClass());
private static final String DESTINATIONS = "destinations";
private static final String CONTENT_TYPE = "contentType";
private static final String MESSAGE_TYPE = "messageType";
private static final String SUBSCRIPTION_ID = "subscriptionId";
private static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
private static final String SESSION_ID = "sessionId";
private static final String RAW_HEADERS = "rawHeaders";
private static final String SUBSCRIPTION_ID = "subscriptionId";
private static final String EXTERNAL_SOURCE_HEADERS = "extSourceHeaders";
private final Map<String, Object> messageHeaders;
private static final Map<String, List<String>> emptyMultiValueMap =
Collections.unmodifiableMap(new LinkedMultiValueMap<String, String>(0));
// wrapped read-only message headers
private final MessageHeaders originalHeaders;
// header updates
private final Map<String, Object> headers = new HashMap<String, Object>(4);
// saved headers from a message from a remote source
private final Map<String, List<String>> externalSourceHeaders;
private final Map<String, String> rawHeaders;
/**
* Constructor for building new headers.
*
* @param messageType the message type
* @param protocolMessageType the protocol-specific message type or command
* A constructor for creating new message headers.
* This constructor is protected. See factory methods in this and sub-classes.
*/
public PubSubHeaders(MessageType messageType, Object protocolMessageType) {
protected PubSubHeaders(MessageType messageType, Object protocolMessageType,
Map<String, List<String>> externalSourceHeaders) {
this.originalHeaders = null;
Assert.notNull(messageType, "messageType is required");
this.headers.put(MESSAGE_TYPE, messageType);
this.messageHeaders = new HashMap<String, Object>();
this.messageHeaders.put(MESSAGE_TYPE, messageType);
if (protocolMessageType != null) {
this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
this.headers.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
}
this.rawHeaders = new HashMap<String, String>();
this.messageHeaders.put(RAW_HEADERS, this.rawHeaders);
}
public PubSubHeaders() {
this(MessageType.MESSAGE, null);
if (externalSourceHeaders == null) {
this.externalSourceHeaders = emptyMultiValueMap;
}
else {
this.externalSourceHeaders = Collections.unmodifiableMap(externalSourceHeaders); // TODO: list values must also be read-only
this.headers.put(EXTERNAL_SOURCE_HEADERS, this.externalSourceHeaders);
}
}
/**
* Constructor for access to existing {@link MessageHeaders}.
*
* @param messageHeaders
* A constructor for accessing and modifying existing message headers. This
* constructor is protected. See factory methods in this and sub-classes.
*/
@SuppressWarnings("unchecked")
public PubSubHeaders(MessageHeaders messageHeaders, boolean readOnly) {
protected PubSubHeaders(MessageHeaders originalHeaders) {
this.originalHeaders = originalHeaders;
this.externalSourceHeaders = (originalHeaders.get(EXTERNAL_SOURCE_HEADERS) != null) ?
(Map<String, List<String>>) originalHeaders.get(EXTERNAL_SOURCE_HEADERS) : emptyMultiValueMap;
}
this.messageHeaders = readOnly ? messageHeaders : new HashMap<String, Object>(messageHeaders);
this.rawHeaders = this.messageHeaders.containsKey(RAW_HEADERS) ?
(Map<String, String>) messageHeaders.get(RAW_HEADERS) : Collections.<String, String>emptyMap();
if (this.messageHeaders.get(MESSAGE_TYPE) == null) {
this.messageHeaders.put(MESSAGE_TYPE, MessageType.MESSAGE);
}
/**
* Create {@link PubSubHeaders} for a new {@link Message}.
*/
public static PubSubHeaders create() {
return new PubSubHeaders(MessageType.MESSAGE, null, null);
}
/**
* Create {@link PubSubHeaders} from existing message headers.
*/
public static PubSubHeaders fromMessageHeaders(MessageHeaders originalHeaders) {
return new PubSubHeaders(originalHeaders);
}
public Map<String, Object> getMessageHeaders() {
return this.messageHeaders;
/**
* Return the original, wrapped headers (i.e. unmodified) or a new Map including any
* updates made via setters.
*/
public Map<String, Object> toMessageHeaders() {
if (!isModified()) {
return this.originalHeaders;
}
Map<String, Object> result = new HashMap<String, Object>();
if (this.originalHeaders != null) {
result.putAll(this.originalHeaders);
}
result.putAll(this.headers);
return result;
}
public Map<String, String> getRawHeaders() {
return this.rawHeaders;
public boolean isModified() {
return ((this.originalHeaders == null) || !this.headers.isEmpty());
}
public MessageType getMessageType() {
return (MessageType) this.messageHeaders.get(MESSAGE_TYPE);
return (MessageType) getHeaderValue(MESSAGE_TYPE);
}
private Object getHeaderValue(String headerName) {
if (this.headers.get(headerName) != null) {
return this.headers.get(headerName);
}
else if (this.originalHeaders.get(headerName) != null) {
return this.originalHeaders.get(headerName);
}
return null;
}
public void setProtocolMessageType(Object protocolMessageType) {
this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
protected void setProtocolMessageType(Object protocolMessageType) {
this.headers.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
}
public Object getProtocolMessageType() {
return this.messageHeaders.get(PROTOCOL_MESSAGE_TYPE);
protected Object getProtocolMessageType() {
return getHeaderValue(PROTOCOL_MESSAGE_TYPE);
}
public void setDestination(String destination) {
this.messageHeaders.put(DESTINATIONS, Arrays.asList(destination));
Assert.notNull(destination, "destination is required");
this.headers.put(DESTINATIONS, Arrays.asList(destination));
}
@SuppressWarnings("unchecked")
public String getDestination() {
@SuppressWarnings("unchecked")
List<String> destination = (List<String>) messageHeaders.get(DESTINATIONS);
return CollectionUtils.isEmpty(destination) ? null : destination.get(0);
List<String> destinations = (List<String>) getHeaderValue(DESTINATIONS);
return CollectionUtils.isEmpty(destinations) ? null : destinations.get(0);
}
@SuppressWarnings("unchecked")
public List<String> getDestinations() {
return (List<String>) messageHeaders.get(DESTINATIONS);
List<String> destinations = (List<String>) getHeaderValue(DESTINATIONS);
return CollectionUtils.isEmpty(destinations) ? null : destinations;
}
public void setDestinations(List<String> destinations) {
if (destinations != null) {
this.messageHeaders.put(DESTINATIONS, destinations);
}
Assert.notNull(destinations, "destinations are required");
this.headers.put(DESTINATIONS, destinations);
}
public MediaType getContentType() {
return (MediaType) this.messageHeaders.get(CONTENT_TYPE);
return (MediaType) getHeaderValue(CONTENT_TYPE);
}
public void setContentType(MediaType mediaType) {
if (mediaType != null) {
this.messageHeaders.put(CONTENT_TYPE, mediaType);
}
public void setContentType(MediaType contentType) {
Assert.notNull(contentType, "contentType is required");
this.headers.put(CONTENT_TYPE, contentType);
}
public String getSubscriptionId() {
return (String) this.messageHeaders.get(SUBSCRIPTION_ID);
return (String) getHeaderValue(SUBSCRIPTION_ID);
}
public void setSubscriptionId(String subscriptionId) {
this.messageHeaders.put(SUBSCRIPTION_ID, subscriptionId);
this.headers.put(SUBSCRIPTION_ID, subscriptionId);
}
public String getSessionId() {
return (String) this.messageHeaders.get(SESSION_ID);
return (String) getHeaderValue(SESSION_ID);
}
public void setSessionId(String sessionId) {
this.messageHeaders.put(SESSION_ID, sessionId);
this.headers.put(SESSION_ID, sessionId);
}
/**
* Return a read-only map of headers originating from a message received by the
* application from an external source (e.g. from a remote WebSocket endpoint). The
* header names and values are exactly as they were, and are protocol specific but may
* also be custom application headers if the protocol allows that.
*/
public Map<String, List<String>> getExternalSourceHeaders() {
return this.externalSourceHeaders;
}
@Override
public String toString() {
return "PubSubHeaders [originalHeaders=" + this.originalHeaders + ", headers="
+ this.headers + ", externalSourceHeaders=" + this.externalSourceHeaders + "]";
}
}

10
spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java

@ -72,10 +72,10 @@ public abstract class AbstractMessageService { @@ -72,10 +72,10 @@ public abstract class AbstractMessageService {
}
if (logger.isTraceEnabled()) {
logger.trace("Processing notification: " + message);
logger.trace("Processing message id=" + message.getHeaders().getId());
}
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
MessageType messageType = headers.getMessageType();
if (messageType == null || messageType.equals(MessageType.OTHER)) {
processOther(message);
@ -129,7 +129,7 @@ public abstract class AbstractMessageService { @@ -129,7 +129,7 @@ public abstract class AbstractMessageService {
}
private boolean isAllowedDestination(Message<?> message) {
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String destination = headers.getDestination();
if (destination == null) {
return true;
@ -138,7 +138,7 @@ public abstract class AbstractMessageService { @@ -138,7 +138,7 @@ public abstract class AbstractMessageService {
for (String pattern : this.disallowedDestinations) {
if (this.pathMatcher.match(pattern, destination)) {
if (logger.isTraceEnabled()) {
logger.trace("Skip notification: " + message);
logger.trace("Skip notification message id=" + message.getHeaders().getId());
}
return false;
}
@ -151,7 +151,7 @@ public abstract class AbstractMessageService { @@ -151,7 +151,7 @@ public abstract class AbstractMessageService {
}
}
if (logger.isTraceEnabled()) {
logger.trace("Skip notification: " + message);
logger.trace("Skip notification message id=" + message.getHeaders().getId());
}
return false;
}

10
spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java

@ -62,7 +62,7 @@ public class PubSubMessageService extends AbstractMessageService { @@ -62,7 +62,7 @@ public class PubSubMessageService extends AbstractMessageService {
try {
// Convert to byte[] payload before the fan-out
PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType());
message = new GenericMessage<byte[]>(payload, message.getHeaders());
@ -82,19 +82,19 @@ public class PubSubMessageService extends AbstractMessageService { @@ -82,19 +82,19 @@ public class PubSubMessageService extends AbstractMessageService {
if (logger.isDebugEnabled()) {
logger.debug("Subscribe " + message);
}
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
final String subscriptionId = headers.getSubscriptionId();
EventRegistration registration = getEventBus().registerConsumer(getPublishKey(headers.getDestination()),
new EventConsumer<Message<?>>() {
@Override
public void accept(Message<?> message) {
PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
PubSubHeaders outHeaders = new PubSubHeaders();
PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaders outHeaders = PubSubHeaders.create();
outHeaders.setDestinations(inHeaders.getDestinations());
outHeaders.setContentType(inHeaders.getContentType());
outHeaders.setSubscriptionId(subscriptionId);
Object payload = message.getPayload();
message = new GenericMessage<Object>(payload, outHeaders.getMessageHeaders());
message = new GenericMessage<Object>(payload, outHeaders.toMessageHeaders());
getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message);
}
});

2
spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java

@ -167,7 +167,7 @@ public class AnnotationMessageService extends AbstractMessageService implements @@ -167,7 +167,7 @@ public class AnnotationMessageService extends AbstractMessageService implements
private void handleMessage(final Message<?> message, Map<MappingInfo, HandlerMethod> handlerMethods) {
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String destination = headers.getDestination();
HandlerMethod match = getHandlerMethod(destination, handlerMethods);

6
spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java

@ -49,16 +49,16 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { @@ -49,16 +49,16 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
@Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
final String sessionId = new PubSubHeaders(message.getHeaders(), true).getSessionId();
final String sessionId = PubSubHeaders.fromMessageHeaders(message.getHeaders()).getSessionId();
return new MessageChannel() {
@Override
public boolean send(Message<?> message) {
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), false);
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
headers.setSessionId(sessionId);
message = new GenericMessage<Object>(message.getPayload(), headers.getMessageHeaders());
message = new GenericMessage<Object>(message.getPayload(), headers.toMessageHeaders());
eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message);

6
spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java

@ -66,15 +66,15 @@ public class MessageReturnValueHandler implements ReturnValueHandler { @@ -66,15 +66,15 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
return;
}
PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String sessionId = inHeaders.getSessionId();
String subscriptionId = inHeaders.getSubscriptionId();
Assert.notNull(subscriptionId, "No subscription id: " + message);
PubSubHeaders outHeaders = new PubSubHeaders(returnMessage.getHeaders(), false);
PubSubHeaders outHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders());
outHeaders.setSessionId(sessionId);
outHeaders.setSubscriptionId(subscriptionId);
returnMessage = new GenericMessage<Object>(returnMessage.getPayload(), outHeaders.getMessageHeaders());
returnMessage = new GenericMessage<Object>(returnMessage.getPayload(), outHeaders.toMessageHeaders());
this.eventBus.send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, returnMessage);
}

284
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java

@ -17,11 +17,17 @@ @@ -17,11 +17,17 @@
package org.springframework.web.messaging.stomp;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.messaging.PubSubHeaders;
@ -29,7 +35,12 @@ import reactor.util.Assert; @@ -29,7 +35,12 @@ import reactor.util.Assert;
/**
* STOMP adapter for {@link MessageHeaders}.
* Can be used to prepare headers for a new STOMP message, or to access and/or modify
* STOMP-specific headers of an existing message.
* <p>
* Use one of the static factory method in this class, then call getters and setters, and
* at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers
* or call {@link #toStompMessageHeaders()} to obtain only the STOMP-specific headers.
*
* @author Rossen Stoyanchev
* @since 4.0
@ -54,6 +65,8 @@ public class StompHeaders extends PubSubHeaders { @@ -54,6 +65,8 @@ public class StompHeaders extends PubSubHeaders {
private static final String ACK = "ack";
private static final String NACK = "nack";
private static final String DESTINATION = "destination";
private static final String CONTENT_TYPE = "content-type";
@ -63,30 +76,133 @@ public class StompHeaders extends PubSubHeaders { @@ -63,30 +76,133 @@ public class StompHeaders extends PubSubHeaders {
private static final String HEARTBEAT = "heart-beat";
private static final String STOMP_HEADERS = "stompHeaders";
private final Map<String, String> headers;
/**
* A constructor for creating new STOMP message headers.
* This constructor is private. See factory methods in this sub-classes.
*/
private StompHeaders(StompCommand command, Map<String, List<String>> externalSourceHeaders) {
super(command.getMessageType(), command, externalSourceHeaders);
this.headers = new HashMap<String, String>(4);
updateMessageHeaders();
}
private void updateMessageHeaders() {
if (getExternalSourceHeaders().isEmpty()) {
return;
}
String destination = getHeaderValue(DESTINATION);
if (destination != null) {
super.setDestination(destination);
}
String contentType = getHeaderValue(CONTENT_TYPE);
if (contentType != null) {
super.setContentType(MediaType.parseMediaType(contentType));
}
if (StompCommand.SUBSCRIBE.equals(getStompCommand())) {
if (getHeaderValue(ID) != null) {
super.setSubscriptionId(getHeaderValue(ID));
}
}
}
/**
* A constructor for accessing and modifying existing message headers. This
* constructor is protected. See factory methods in this class.
*/
@SuppressWarnings("unchecked")
private StompHeaders(MessageHeaders messageHeaders) {
super(messageHeaders);
this.headers = (messageHeaders.get(STOMP_HEADERS) != null) ?
(Map<String, String>) messageHeaders.get(STOMP_HEADERS) : new HashMap<String, String>(4);
}
/**
* Create {@link StompHeaders} for a new {@link Message}.
*/
public static StompHeaders create(StompCommand command) {
return new StompHeaders(command, null);
}
/**
* Constructor for building new headers.
*
* @param command the STOMP command
* Create {@link StompHeaders} from the headers of an existing {@link Message}.
*/
public StompHeaders(StompCommand command) {
super(command.getMessageType(), command);
public static StompHeaders fromMessageHeaders(MessageHeaders messageHeaders) {
return new StompHeaders(messageHeaders);
}
/**
* Constructor for access to existing {@link MessageHeaders}.
*
* @param messageHeaders the existing message headers
* @param readOnly whether the resulting instance will be used for read-only access,
* if {@code true}, then set methods will throw exceptions; if {@code false}
* they will work.
* Create {@link StompHeaders} from parsed STOP frame content.
*/
public StompHeaders(MessageHeaders messageHeaders, boolean readOnly) {
super(messageHeaders, readOnly);
public static StompHeaders fromParsedFrame(StompCommand command, Map<String, List<String>> headers) {
return new StompHeaders(command, headers);
}
/**
* Return the original, wrapped headers (i.e. unmodified) or a new Map including any
* updates made via setters.
*/
@Override
public StompCommand getProtocolMessageType() {
return (StompCommand) super.getProtocolMessageType();
public Map<String, Object> toMessageHeaders() {
Map<String, Object> result = super.toMessageHeaders();
if (isModified()) {
result.put(STOMP_HEADERS, this.headers);
}
return result;
}
@Override
public boolean isModified() {
return (super.isModified() || !this.headers.isEmpty());
}
/**
* Return STOMP headers and any custom headers that may have been sent by
* a remote endpoint, if this message originated from outside.
*/
public Map<String, List<String>> toStompMessageHeaders() {
MultiValueMap<String, String> result = new LinkedMultiValueMap<String, String>();
result.putAll(getExternalSourceHeaders());
result.setAll(this.headers);
String destination = super.getDestination();
if (destination != null) {
result.set(DESTINATION, destination);
}
MediaType contentType = getContentType();
if (contentType != null) {
result.set(CONTENT_TYPE, contentType.toString());
}
if (StompCommand.MESSAGE.equals(getStompCommand())) {
String subscriptionId = getSubscriptionId();
if (subscriptionId != null) {
result.set(SUBSCRIPTION, subscriptionId);
}
else {
logger.warn("STOMP MESSAGE frame should have a subscription: " + this.toString());
}
if ((getMessageId() == null)) {
this.headers.put(MESSAGE_ID, toMessageHeaders().get(ID).toString());
}
}
return result;
}
public void setStompCommandIfNotSet(StompCommand command) {
if (getStompCommand() == null) {
setProtocolMessageType(command);
}
}
public StompCommand getStompCommand() {
@ -94,32 +210,42 @@ public class StompHeaders extends PubSubHeaders { @@ -94,32 +210,42 @@ public class StompHeaders extends PubSubHeaders {
}
public Set<String> getAcceptVersion() {
String rawValue = getRawHeaders().get(ACCEPT_VERSION);
String rawValue = getHeaderValue(ACCEPT_VERSION);
return (rawValue != null) ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.<String>emptySet();
}
private String getHeaderValue(String headerName) {
List<String> values = getExternalSourceHeaders().get(headerName);
return !CollectionUtils.isEmpty(values) ? values.get(0) : this.headers.get(headerName);
}
public void setAcceptVersion(String acceptVersion) {
getRawHeaders().put(ACCEPT_VERSION, acceptVersion);
this.headers.put(ACCEPT_VERSION, acceptVersion);
}
public void setHost(String host) {
this.headers.put(HOST, host);
}
public String getHost() {
return getHeaderValue(HOST);
}
@Override
public void setDestination(String destination) {
if (destination != null) {
super.setDestination(destination);
getRawHeaders().put(DESTINATION, destination);
}
super.setDestination(destination);
this.headers.put(DESTINATION, destination);
}
@Override
public void setDestinations(List<String> destinations) {
if (destinations != null) {
super.setDestinations(destinations);
getRawHeaders().put(DESTINATION, destinations.get(0));
}
Assert.isTrue((destinations != null) && (destinations.size() == 1), "STOMP allows one destination per message");
super.setDestinations(destinations);
this.headers.put(DESTINATION, destinations.get(0));
}
public long[] getHeartbeat() {
String rawValue = getRawHeaders().get(HEARTBEAT);
String rawValue = getHeaderValue(HEARTBEAT);
if (!StringUtils.hasText(rawValue)) {
return null;
}
@ -131,102 +257,82 @@ public class StompHeaders extends PubSubHeaders { @@ -131,102 +257,82 @@ public class StompHeaders extends PubSubHeaders {
public void setContentType(MediaType mediaType) {
if (mediaType != null) {
super.setContentType(mediaType);
getRawHeaders().put(CONTENT_TYPE, mediaType.toString());
this.headers.put(CONTENT_TYPE, mediaType.toString());
}
}
public MediaType getContentType() {
String value = getHeaderValue(CONTENT_TYPE);
return (value != null) ? MediaType.parseMediaType(value) : null;
}
public Integer getContentLength() {
String contentLength = getRawHeaders().get(CONTENT_LENGTH);
String contentLength = getHeaderValue(CONTENT_LENGTH);
return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null;
}
public void setContentLength(int contentLength) {
getRawHeaders().put(CONTENT_LENGTH, String.valueOf(contentLength));
this.headers.put(CONTENT_LENGTH, String.valueOf(contentLength));
}
@Override
public String getSubscriptionId() {
return StompCommand.SUBSCRIBE.equals(getStompCommand()) ? getRawHeaders().get(ID) : null;
public void setHeartbeat(long cx, long cy) {
this.headers.put(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
}
@Override
public void setSubscriptionId(String subscriptionId) {
Assert.isTrue(StompCommand.MESSAGE.equals(getStompCommand()),
"\"subscription\" can only be set on a STOMP MESSAGE frame");
super.setSubscriptionId(subscriptionId);
getRawHeaders().put(SUBSCRIPTION, subscriptionId);
public void setAck(String ack) {
this.headers.put(ACK, ack);
}
public void setHeartbeat(long cx, long cy) {
getRawHeaders().put(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
public String getAck() {
return getHeaderValue(ACK);
}
public void setNack(String nack) {
this.headers.put(NACK, nack);
}
public String getNack() {
return getHeaderValue(NACK);
}
public void setReceiptId(String receiptId) {
this.headers.put(RECEIPT_ID, receiptId);
}
public String getReceiptId() {
return getHeaderValue(RECEIPT_ID);
}
public String getMessage() {
return getRawHeaders().get(MESSAGE);
return getHeaderValue(MESSAGE);
}
public void setMessage(String content) {
getRawHeaders().put(MESSAGE, content);
this.headers.put(MESSAGE, content);
}
public String getMessageId() {
return getRawHeaders().get(MESSAGE_ID);
return getHeaderValue(MESSAGE_ID);
}
public void setMessageId(String id) {
getRawHeaders().put(MESSAGE_ID, id);
this.headers.put(MESSAGE_ID, id);
}
public String getVersion() {
return getRawHeaders().get(VERSION);
return getHeaderValue(VERSION);
}
public void setVersion(String version) {
getRawHeaders().put(VERSION, version);
}
/**
* Update generic message headers from raw headers. This method only needs to be
* invoked when raw headers are added via {@link #getRawHeaders()}.
*/
public void updateMessageHeaders() {
String destination = getRawHeaders().get(DESTINATION);
if (destination != null) {
setDestination(destination);
}
String contentType = getRawHeaders().get(CONTENT_TYPE);
if (contentType != null) {
setContentType(MediaType.parseMediaType(contentType));
}
if (StompCommand.SUBSCRIBE.equals(getStompCommand())) {
if (getRawHeaders().get(ID) != null) {
super.setSubscriptionId(getRawHeaders().get(ID));
}
}
this.headers.put(VERSION, version);
}
/**
* Update raw headers from generic message headers. This method only needs to be
* invoked if creating {@link StompHeaders} from {@link MessageHeaders} that never
* contained raw headers.
*/
public void updateRawHeaders() {
String destination = getDestination();
if (destination != null) {
getRawHeaders().put(DESTINATION, destination);
}
MediaType contentType = getContentType();
if (contentType != null) {
getRawHeaders().put(CONTENT_TYPE, contentType.toString());
}
String subscriptionId = getSubscriptionId();
if (subscriptionId != null) {
getRawHeaders().put(SUBSCRIPTION, subscriptionId);
}
if (StompCommand.MESSAGE.equals(getStompCommand()) && (getMessageId() == null)) {
getRawHeaders().put(MESSAGE_ID, getMessageHeaders().get(ID).toString());
}
@Override
public String toString() {
return "StompHeaders [" + "messageType=" + getMessageType() + ", protocolMessageType="
+ getProtocolMessageType() + ", destination=" + getDestination()
+ ", subscriptionId=" + getSubscriptionId() + ", sessionId=" + getSessionId()
+ ", externalSourceHeaders=" + getExternalSourceHeaders() + ", headers=" + this.headers + "]";
}
}

6
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java

@ -21,8 +21,8 @@ import java.util.concurrent.ConcurrentHashMap; @@ -21,8 +21,8 @@ import java.util.concurrent.ConcurrentHashMap;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.support.StompMessageConverter;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
@ -76,10 +76,10 @@ public abstract class AbstractStompWebSocketHandler extends TextWebSocketHandler @@ -76,10 +76,10 @@ public abstract class AbstractStompWebSocketHandler extends TextWebSocketHandler
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR);
StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR);
stompHeaders.setMessage(error.getMessage());
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.getMessageHeaders());
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.toMessageHeaders());
byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage);
try {

21
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java

@ -63,10 +63,8 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler @@ -63,10 +63,8 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
@Override
public void accept(Message<?> message) {
StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false);
if (stompHeaders.getProtocolMessageType() == null) {
stompHeaders.setProtocolMessageType(StompCommand.MESSAGE);
}
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
stompHeaders.setStompCommandIfNotSet(StompCommand.MESSAGE);
if (StompCommand.CONNECTED.equals(stompHeaders.getStompCommand())) {
// Ignore for now since we already sent it
@ -74,6 +72,9 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler @@ -74,6 +72,9 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
}
String sessionId = stompHeaders.getSessionId();
if (sessionId == null) {
logger.error("Cannot send message without a session id: " + message);
}
WebSocketSession session = getWebSocketSession(sessionId);
byte[] payload;
@ -87,7 +88,7 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler @@ -87,7 +88,7 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
}
try {
Map<String, Object> messageHeaders = stompHeaders.getMessageHeaders();
Map<String, Object> messageHeaders = stompHeaders.toMessageHeaders();
Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, messageHeaders);
byte[] bytes = getStompMessageConverter().fromMessage(byteMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
@ -116,11 +117,11 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler @@ -116,11 +117,11 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
public void handleStompMessage(final WebSocketSession session, Message<byte[]> message) {
if (logger.isTraceEnabled()) {
logger.trace("Processing: " + message);
logger.trace("Processing STOMP message: " + message);
}
try {
StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), true);
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
MessageType messageType = stompHeaders.getMessageType();
if (MessageType.CONNECT.equals(messageType)) {
handleConnect(session, message);
@ -147,8 +148,8 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler @@ -147,8 +148,8 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
protected void handleConnect(final WebSocketSession session, Message<byte[]> message) throws IOException {
StompHeaders connectStompHeaders = new StompHeaders(message.getHeaders(), true);
StompHeaders connectedStompHeaders = new StompHeaders(StompCommand.CONNECTED);
StompHeaders connectStompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaders connectedStompHeaders = StompHeaders.create(StompCommand.CONNECTED);
Set<String> acceptVersions = connectStompHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) {
@ -167,7 +168,7 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler @@ -167,7 +168,7 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
// TODO: security
Message<byte[]> connectedMessage = new GenericMessage<byte[]>(new byte[0], connectedStompHeaders.getMessageHeaders());
Message<byte[]> connectedMessage = new GenericMessage<byte[]>(new byte[0], connectedStompHeaders.toMessageHeaders());
byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}

19
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java

@ -93,21 +93,20 @@ public class RelayStompService extends AbstractMessageService { @@ -93,21 +93,20 @@ public class RelayStompService extends AbstractMessageService {
private void forwardMessage(Message<?> message, StompCommand command) {
StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false);
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
String sessionId = stompHeaders.getSessionId();
RelaySession session = RelayStompService.this.relaySessions.get(sessionId);
Assert.notNull(session, "RelaySession not found");
try {
if (stompHeaders.getProtocolMessageType() == null) {
stompHeaders.setProtocolMessageType(StompCommand.SEND);
}
stompHeaders.setStompCommandIfNotSet(StompCommand.SEND);
MediaType contentType = stompHeaders.getContentType();
byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType);
Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, stompHeaders.getMessageHeaders());
Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, stompHeaders.toMessageHeaders());
if (logger.isTraceEnabled()) {
logger.trace("Forwarding: " + byteMessage);
logger.trace("Forwarding STOMP " + stompHeaders.getStompCommand() + " message");
}
byte[] bytes = this.stompMessageConverter.fromMessage(byteMessage);
@ -115,7 +114,7 @@ public class RelayStompService extends AbstractMessageService { @@ -115,7 +114,7 @@ public class RelayStompService extends AbstractMessageService {
session.getOutputStream().flush();
}
catch (Exception ex) {
logger.error("Couldn't forward message", ex);
logger.error("Couldn't forward message " + message, ex);
clearRelaySession(sessionId);
}
}
@ -233,13 +232,15 @@ public class RelayStompService extends AbstractMessageService { @@ -233,13 +232,15 @@ public class RelayStompService extends AbstractMessageService {
catch (IOException e) {
logger.error("Socket error: " + e.getMessage());
clearRelaySession(sessionId);
sendErrorMessage("Lost connection");
}
}
private void sendErrorMessage(String message) {
StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR);
StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR);
stompHeaders.setMessage(message);
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.getMessageHeaders());
stompHeaders.setSessionId(this.sessionId);
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.toMessageHeaders());
getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, errorMessage);
}
}

34
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java

@ -18,6 +18,7 @@ package org.springframework.web.messaging.stomp.support; @@ -18,6 +18,7 @@ package org.springframework.web.messaging.stomp.support;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
@ -25,6 +26,8 @@ import org.springframework.messaging.GenericMessage; @@ -25,6 +26,8 @@ import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompConversionException;
import org.springframework.web.messaging.stomp.StompHeaders;
@ -80,15 +83,13 @@ public class StompMessageConverter { @@ -80,15 +83,13 @@ public class StompMessageConverter {
StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim());
Assert.notNull(command, "No command found");
StompHeaders stompHeaders = new StompHeaders(command);
stompHeaders.setSessionId(sessionId);
MultiValueMap<String, String> headers = new LinkedMultiValueMap<String, String>();
while (parser.hasNext()) {
String header = parser.nextToken(COLON);
if (header != null) {
if (parser.hasNext()) {
String value = parser.nextToken(LF);
stompHeaders.getRawHeaders().put(header, value);
headers.add(header, value);
}
else {
throw new StompConversionException("Parse exception for " + headerContent);
@ -96,12 +97,13 @@ public class StompMessageConverter { @@ -96,12 +97,13 @@ public class StompMessageConverter {
}
}
StompHeaders stompHeaders = StompHeaders.fromParsedFrame(command, headers);
stompHeaders.setSessionId(sessionId);
byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
stompHeaders.updateMessageHeaders();
return createMessage(command, stompHeaders.getMessageHeaders(), payload);
return createMessage(command, stompHeaders.toMessageHeaders(), payload);
}
private int findIndexOfPayload(byte[] bytes) {
@ -138,20 +140,20 @@ public class StompMessageConverter { @@ -138,20 +140,20 @@ public class StompMessageConverter {
public byte[] fromMessage(Message<byte[]> message) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, false);
stompHeaders.updateRawHeaders();
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
try {
out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8"));
out.write(LF);
for (Entry<String, String> entry : stompHeaders.getRawHeaders().entrySet()) {
for (Entry<String, List<String>> entry : stompHeaders.toStompMessageHeaders().entrySet()) {
String key = entry.getKey();
key = replaceAllOutbound(key);
String value = entry.getValue();
out.write(key.getBytes("UTF-8"));
out.write(COLON);
value = replaceAllOutbound(value);
out.write(value.getBytes("UTF-8"));
out.write(LF);
for (String value : entry.getValue()) {
out.write(key.getBytes("UTF-8"));
out.write(COLON);
value = replaceAllOutbound(value);
out.write(value.getBytes("UTF-8"));
out.write(LF);
}
}
out.write(LF);
out.write(message.getPayload());

24
spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java

@ -22,8 +22,8 @@ import org.junit.Test; @@ -22,8 +22,8 @@ import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompHeaders;
import static org.junit.Assert.*;
@ -52,15 +52,17 @@ public class StompMessageConverterTests { @@ -52,15 +52,17 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
assertEquals(6, stompHeaders.getMessageHeaders().size());
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
assertEquals(7, stompHeaders.toMessageHeaders().size());
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost());
assertEquals(MessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand());
assertEquals("session-123", stompHeaders.getSessionId());
assertNotNull(messageHeaders.get(MessageHeaders.ID));
assertNotNull(messageHeaders.get(MessageHeaders.TIMESTAMP));
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getRawHeaders().get("host"));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
@ -80,9 +82,9 @@ public class StompMessageConverterTests { @@ -80,9 +82,9 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt"));
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getExternalSourceHeaders().get("ho:\ns\rt").get(0));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
@ -102,9 +104,9 @@ public class StompMessageConverterTests { @@ -102,9 +104,9 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getRawHeaders().get("host"));
assertEquals("github.org", stompHeaders.getHost());
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
@ -124,9 +126,9 @@ public class StompMessageConverterTests { @@ -124,9 +126,9 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt"));
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getExternalSourceHeaders().get("ho:\ns\rt").get(0));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");

Loading…
Cancel
Save