diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java index e2a15998adb..2aebcabc3ea 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java @@ -22,143 +22,220 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.http.MediaType; +import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; /** + * A base class for working with message headers in Web, messaging protocols that support + * the publish-subscribe message pattern. Provides uniform access to specific values + * common across protocols such as a destination, message type (publish, + * subscribe/unsubscribe), session id, and others. + *

+ * This class can be used to prepare headers for a new pub-sub message, or to access + * and/or modify headers of an existing message. + *

+ * Use one of the static factory method in this class, then call getters and setters, and + * at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers. * * @author Rossen Stoyanchev * @since 4.0 */ public class PubSubHeaders { + protected Log logger = LogFactory.getLog(getClass()); + private static final String DESTINATIONS = "destinations"; private static final String CONTENT_TYPE = "contentType"; private static final String MESSAGE_TYPE = "messageType"; - private static final String SUBSCRIPTION_ID = "subscriptionId"; - private static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType"; private static final String SESSION_ID = "sessionId"; - private static final String RAW_HEADERS = "rawHeaders"; + private static final String SUBSCRIPTION_ID = "subscriptionId"; + + private static final String EXTERNAL_SOURCE_HEADERS = "extSourceHeaders"; - private final Map messageHeaders; + private static final Map> emptyMultiValueMap = + Collections.unmodifiableMap(new LinkedMultiValueMap(0)); + + + // wrapped read-only message headers + private final MessageHeaders originalHeaders; + + // header updates + private final Map headers = new HashMap(4); + + // saved headers from a message from a remote source + private final Map> externalSourceHeaders; - private final Map rawHeaders; /** - * Constructor for building new headers. - * - * @param messageType the message type - * @param protocolMessageType the protocol-specific message type or command + * A constructor for creating new message headers. + * This constructor is protected. See factory methods in this and sub-classes. */ - public PubSubHeaders(MessageType messageType, Object protocolMessageType) { + protected PubSubHeaders(MessageType messageType, Object protocolMessageType, + Map> externalSourceHeaders) { + + this.originalHeaders = null; + + Assert.notNull(messageType, "messageType is required"); + this.headers.put(MESSAGE_TYPE, messageType); - this.messageHeaders = new HashMap(); - this.messageHeaders.put(MESSAGE_TYPE, messageType); if (protocolMessageType != null) { - this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType); + this.headers.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType); } - this.rawHeaders = new HashMap(); - this.messageHeaders.put(RAW_HEADERS, this.rawHeaders); - } - - public PubSubHeaders() { - this(MessageType.MESSAGE, null); + if (externalSourceHeaders == null) { + this.externalSourceHeaders = emptyMultiValueMap; + } + else { + this.externalSourceHeaders = Collections.unmodifiableMap(externalSourceHeaders); // TODO: list values must also be read-only + this.headers.put(EXTERNAL_SOURCE_HEADERS, this.externalSourceHeaders); + } } /** - * Constructor for access to existing {@link MessageHeaders}. - * - * @param messageHeaders + * A constructor for accessing and modifying existing message headers. This + * constructor is protected. See factory methods in this and sub-classes. */ @SuppressWarnings("unchecked") - public PubSubHeaders(MessageHeaders messageHeaders, boolean readOnly) { + protected PubSubHeaders(MessageHeaders originalHeaders) { + this.originalHeaders = originalHeaders; + this.externalSourceHeaders = (originalHeaders.get(EXTERNAL_SOURCE_HEADERS) != null) ? + (Map>) originalHeaders.get(EXTERNAL_SOURCE_HEADERS) : emptyMultiValueMap; + } - this.messageHeaders = readOnly ? messageHeaders : new HashMap(messageHeaders); - this.rawHeaders = this.messageHeaders.containsKey(RAW_HEADERS) ? - (Map) messageHeaders.get(RAW_HEADERS) : Collections.emptyMap(); - if (this.messageHeaders.get(MESSAGE_TYPE) == null) { - this.messageHeaders.put(MESSAGE_TYPE, MessageType.MESSAGE); - } + /** + * Create {@link PubSubHeaders} for a new {@link Message}. + */ + public static PubSubHeaders create() { + return new PubSubHeaders(MessageType.MESSAGE, null, null); + } + + /** + * Create {@link PubSubHeaders} from existing message headers. + */ + public static PubSubHeaders fromMessageHeaders(MessageHeaders originalHeaders) { + return new PubSubHeaders(originalHeaders); } - public Map getMessageHeaders() { - return this.messageHeaders; + /** + * Return the original, wrapped headers (i.e. unmodified) or a new Map including any + * updates made via setters. + */ + public Map toMessageHeaders() { + if (!isModified()) { + return this.originalHeaders; + } + Map result = new HashMap(); + if (this.originalHeaders != null) { + result.putAll(this.originalHeaders); + } + result.putAll(this.headers); + return result; } - public Map getRawHeaders() { - return this.rawHeaders; + public boolean isModified() { + return ((this.originalHeaders == null) || !this.headers.isEmpty()); } public MessageType getMessageType() { - return (MessageType) this.messageHeaders.get(MESSAGE_TYPE); + return (MessageType) getHeaderValue(MESSAGE_TYPE); + } + + private Object getHeaderValue(String headerName) { + if (this.headers.get(headerName) != null) { + return this.headers.get(headerName); + } + else if (this.originalHeaders.get(headerName) != null) { + return this.originalHeaders.get(headerName); + } + return null; } - public void setProtocolMessageType(Object protocolMessageType) { - this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType); + protected void setProtocolMessageType(Object protocolMessageType) { + this.headers.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType); } - public Object getProtocolMessageType() { - return this.messageHeaders.get(PROTOCOL_MESSAGE_TYPE); + protected Object getProtocolMessageType() { + return getHeaderValue(PROTOCOL_MESSAGE_TYPE); } public void setDestination(String destination) { - this.messageHeaders.put(DESTINATIONS, Arrays.asList(destination)); + Assert.notNull(destination, "destination is required"); + this.headers.put(DESTINATIONS, Arrays.asList(destination)); } + @SuppressWarnings("unchecked") public String getDestination() { - @SuppressWarnings("unchecked") - List destination = (List) messageHeaders.get(DESTINATIONS); - return CollectionUtils.isEmpty(destination) ? null : destination.get(0); + List destinations = (List) getHeaderValue(DESTINATIONS); + return CollectionUtils.isEmpty(destinations) ? null : destinations.get(0); } @SuppressWarnings("unchecked") public List getDestinations() { - return (List) messageHeaders.get(DESTINATIONS); + List destinations = (List) getHeaderValue(DESTINATIONS); + return CollectionUtils.isEmpty(destinations) ? null : destinations; } public void setDestinations(List destinations) { - if (destinations != null) { - this.messageHeaders.put(DESTINATIONS, destinations); - } + Assert.notNull(destinations, "destinations are required"); + this.headers.put(DESTINATIONS, destinations); } public MediaType getContentType() { - return (MediaType) this.messageHeaders.get(CONTENT_TYPE); + return (MediaType) getHeaderValue(CONTENT_TYPE); } - public void setContentType(MediaType mediaType) { - if (mediaType != null) { - this.messageHeaders.put(CONTENT_TYPE, mediaType); - } + public void setContentType(MediaType contentType) { + Assert.notNull(contentType, "contentType is required"); + this.headers.put(CONTENT_TYPE, contentType); } public String getSubscriptionId() { - return (String) this.messageHeaders.get(SUBSCRIPTION_ID); + return (String) getHeaderValue(SUBSCRIPTION_ID); } public void setSubscriptionId(String subscriptionId) { - this.messageHeaders.put(SUBSCRIPTION_ID, subscriptionId); + this.headers.put(SUBSCRIPTION_ID, subscriptionId); } public String getSessionId() { - return (String) this.messageHeaders.get(SESSION_ID); + return (String) getHeaderValue(SESSION_ID); } public void setSessionId(String sessionId) { - this.messageHeaders.put(SESSION_ID, sessionId); + this.headers.put(SESSION_ID, sessionId); + } + + /** + * Return a read-only map of headers originating from a message received by the + * application from an external source (e.g. from a remote WebSocket endpoint). The + * header names and values are exactly as they were, and are protocol specific but may + * also be custom application headers if the protocol allows that. + */ + public Map> getExternalSourceHeaders() { + return this.externalSourceHeaders; + } + + @Override + public String toString() { + return "PubSubHeaders [originalHeaders=" + this.originalHeaders + ", headers=" + + this.headers + ", externalSourceHeaders=" + this.externalSourceHeaders + "]"; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java index c3a4a880472..89816b75b11 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java @@ -72,10 +72,10 @@ public abstract class AbstractMessageService { } if (logger.isTraceEnabled()) { - logger.trace("Processing notification: " + message); + logger.trace("Processing message id=" + message.getHeaders().getId()); } - PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true); + PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); MessageType messageType = headers.getMessageType(); if (messageType == null || messageType.equals(MessageType.OTHER)) { processOther(message); @@ -129,7 +129,7 @@ public abstract class AbstractMessageService { } private boolean isAllowedDestination(Message message) { - PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true); + PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); String destination = headers.getDestination(); if (destination == null) { return true; @@ -138,7 +138,7 @@ public abstract class AbstractMessageService { for (String pattern : this.disallowedDestinations) { if (this.pathMatcher.match(pattern, destination)) { if (logger.isTraceEnabled()) { - logger.trace("Skip notification: " + message); + logger.trace("Skip notification message id=" + message.getHeaders().getId()); } return false; } @@ -151,7 +151,7 @@ public abstract class AbstractMessageService { } } if (logger.isTraceEnabled()) { - logger.trace("Skip notification: " + message); + logger.trace("Skip notification message id=" + message.getHeaders().getId()); } return false; } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java index 963a978368b..b81cdc03a7c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java @@ -62,7 +62,7 @@ public class PubSubMessageService extends AbstractMessageService { try { // Convert to byte[] payload before the fan-out - PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true); + PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders()); byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType()); message = new GenericMessage(payload, message.getHeaders()); @@ -82,19 +82,19 @@ public class PubSubMessageService extends AbstractMessageService { if (logger.isDebugEnabled()) { logger.debug("Subscribe " + message); } - PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true); + PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); final String subscriptionId = headers.getSubscriptionId(); EventRegistration registration = getEventBus().registerConsumer(getPublishKey(headers.getDestination()), new EventConsumer>() { @Override public void accept(Message message) { - PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true); - PubSubHeaders outHeaders = new PubSubHeaders(); + PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + PubSubHeaders outHeaders = PubSubHeaders.create(); outHeaders.setDestinations(inHeaders.getDestinations()); outHeaders.setContentType(inHeaders.getContentType()); outHeaders.setSubscriptionId(subscriptionId); Object payload = message.getPayload(); - message = new GenericMessage(payload, outHeaders.getMessageHeaders()); + message = new GenericMessage(payload, outHeaders.toMessageHeaders()); getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message); } }); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java index 55ab59ea86b..f2e7d13c732 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationMessageService.java @@ -167,7 +167,7 @@ public class AnnotationMessageService extends AbstractMessageService implements private void handleMessage(final Message message, Map handlerMethods) { - PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true); + PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); String destination = headers.getDestination(); HandlerMethod match = getHandlerMethod(destination, handlerMethods); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java index ad16b5f97c0..2196b48ac2d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java @@ -49,16 +49,16 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { @Override public Object resolveArgument(MethodParameter parameter, Message message) throws Exception { - final String sessionId = new PubSubHeaders(message.getHeaders(), true).getSessionId(); + final String sessionId = PubSubHeaders.fromMessageHeaders(message.getHeaders()).getSessionId(); return new MessageChannel() { @Override public boolean send(Message message) { - PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), false); + PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); headers.setSessionId(sessionId); - message = new GenericMessage(message.getPayload(), headers.getMessageHeaders()); + message = new GenericMessage(message.getPayload(), headers.toMessageHeaders()); eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java index 4a357b94d61..30268979be8 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java @@ -66,15 +66,15 @@ public class MessageReturnValueHandler implements ReturnValueHandler { return; } - PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true); + PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders()); String sessionId = inHeaders.getSessionId(); String subscriptionId = inHeaders.getSubscriptionId(); Assert.notNull(subscriptionId, "No subscription id: " + message); - PubSubHeaders outHeaders = new PubSubHeaders(returnMessage.getHeaders(), false); + PubSubHeaders outHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders()); outHeaders.setSessionId(sessionId); outHeaders.setSubscriptionId(subscriptionId); - returnMessage = new GenericMessage(returnMessage.getPayload(), outHeaders.getMessageHeaders()); + returnMessage = new GenericMessage(returnMessage.getPayload(), outHeaders.toMessageHeaders()); this.eventBus.send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, returnMessage); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java index f629809b593..20926a9b4e0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java @@ -17,11 +17,17 @@ package org.springframework.web.messaging.stomp; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import org.springframework.http.MediaType; +import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.messaging.PubSubHeaders; @@ -29,7 +35,12 @@ import reactor.util.Assert; /** - * STOMP adapter for {@link MessageHeaders}. + * Can be used to prepare headers for a new STOMP message, or to access and/or modify + * STOMP-specific headers of an existing message. + *

+ * Use one of the static factory method in this class, then call getters and setters, and + * at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers + * or call {@link #toStompMessageHeaders()} to obtain only the STOMP-specific headers. * * @author Rossen Stoyanchev * @since 4.0 @@ -54,6 +65,8 @@ public class StompHeaders extends PubSubHeaders { private static final String ACK = "ack"; + private static final String NACK = "nack"; + private static final String DESTINATION = "destination"; private static final String CONTENT_TYPE = "content-type"; @@ -63,30 +76,133 @@ public class StompHeaders extends PubSubHeaders { private static final String HEARTBEAT = "heart-beat"; + private static final String STOMP_HEADERS = "stompHeaders"; + + + private final Map headers; + + + /** + * A constructor for creating new STOMP message headers. + * This constructor is private. See factory methods in this sub-classes. + */ + private StompHeaders(StompCommand command, Map> externalSourceHeaders) { + super(command.getMessageType(), command, externalSourceHeaders); + this.headers = new HashMap(4); + updateMessageHeaders(); + } + + private void updateMessageHeaders() { + if (getExternalSourceHeaders().isEmpty()) { + return; + } + String destination = getHeaderValue(DESTINATION); + if (destination != null) { + super.setDestination(destination); + } + String contentType = getHeaderValue(CONTENT_TYPE); + if (contentType != null) { + super.setContentType(MediaType.parseMediaType(contentType)); + } + if (StompCommand.SUBSCRIBE.equals(getStompCommand())) { + if (getHeaderValue(ID) != null) { + super.setSubscriptionId(getHeaderValue(ID)); + } + } + } + + /** + * A constructor for accessing and modifying existing message headers. This + * constructor is protected. See factory methods in this class. + */ + @SuppressWarnings("unchecked") + private StompHeaders(MessageHeaders messageHeaders) { + super(messageHeaders); + this.headers = (messageHeaders.get(STOMP_HEADERS) != null) ? + (Map) messageHeaders.get(STOMP_HEADERS) : new HashMap(4); + } + + + /** + * Create {@link StompHeaders} for a new {@link Message}. + */ + public static StompHeaders create(StompCommand command) { + return new StompHeaders(command, null); + } + /** - * Constructor for building new headers. - * - * @param command the STOMP command + * Create {@link StompHeaders} from the headers of an existing {@link Message}. */ - public StompHeaders(StompCommand command) { - super(command.getMessageType(), command); + public static StompHeaders fromMessageHeaders(MessageHeaders messageHeaders) { + return new StompHeaders(messageHeaders); } /** - * Constructor for access to existing {@link MessageHeaders}. - * - * @param messageHeaders the existing message headers - * @param readOnly whether the resulting instance will be used for read-only access, - * if {@code true}, then set methods will throw exceptions; if {@code false} - * they will work. + * Create {@link StompHeaders} from parsed STOP frame content. */ - public StompHeaders(MessageHeaders messageHeaders, boolean readOnly) { - super(messageHeaders, readOnly); + public static StompHeaders fromParsedFrame(StompCommand command, Map> headers) { + return new StompHeaders(command, headers); } + + /** + * Return the original, wrapped headers (i.e. unmodified) or a new Map including any + * updates made via setters. + */ @Override - public StompCommand getProtocolMessageType() { - return (StompCommand) super.getProtocolMessageType(); + public Map toMessageHeaders() { + Map result = super.toMessageHeaders(); + if (isModified()) { + result.put(STOMP_HEADERS, this.headers); + } + return result; + } + + @Override + public boolean isModified() { + return (super.isModified() || !this.headers.isEmpty()); + } + + /** + * Return STOMP headers and any custom headers that may have been sent by + * a remote endpoint, if this message originated from outside. + */ + public Map> toStompMessageHeaders() { + + MultiValueMap result = new LinkedMultiValueMap(); + result.putAll(getExternalSourceHeaders()); + result.setAll(this.headers); + + String destination = super.getDestination(); + if (destination != null) { + result.set(DESTINATION, destination); + } + + MediaType contentType = getContentType(); + if (contentType != null) { + result.set(CONTENT_TYPE, contentType.toString()); + } + + if (StompCommand.MESSAGE.equals(getStompCommand())) { + String subscriptionId = getSubscriptionId(); + if (subscriptionId != null) { + result.set(SUBSCRIPTION, subscriptionId); + } + else { + logger.warn("STOMP MESSAGE frame should have a subscription: " + this.toString()); + } + if ((getMessageId() == null)) { + this.headers.put(MESSAGE_ID, toMessageHeaders().get(ID).toString()); + } + } + + return result; + } + + public void setStompCommandIfNotSet(StompCommand command) { + if (getStompCommand() == null) { + setProtocolMessageType(command); + } } public StompCommand getStompCommand() { @@ -94,32 +210,42 @@ public class StompHeaders extends PubSubHeaders { } public Set getAcceptVersion() { - String rawValue = getRawHeaders().get(ACCEPT_VERSION); + String rawValue = getHeaderValue(ACCEPT_VERSION); return (rawValue != null) ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.emptySet(); } + private String getHeaderValue(String headerName) { + List values = getExternalSourceHeaders().get(headerName); + return !CollectionUtils.isEmpty(values) ? values.get(0) : this.headers.get(headerName); + } + public void setAcceptVersion(String acceptVersion) { - getRawHeaders().put(ACCEPT_VERSION, acceptVersion); + this.headers.put(ACCEPT_VERSION, acceptVersion); + } + + public void setHost(String host) { + this.headers.put(HOST, host); + } + + public String getHost() { + return getHeaderValue(HOST); } @Override public void setDestination(String destination) { - if (destination != null) { - super.setDestination(destination); - getRawHeaders().put(DESTINATION, destination); - } + super.setDestination(destination); + this.headers.put(DESTINATION, destination); } @Override public void setDestinations(List destinations) { - if (destinations != null) { - super.setDestinations(destinations); - getRawHeaders().put(DESTINATION, destinations.get(0)); - } + Assert.isTrue((destinations != null) && (destinations.size() == 1), "STOMP allows one destination per message"); + super.setDestinations(destinations); + this.headers.put(DESTINATION, destinations.get(0)); } public long[] getHeartbeat() { - String rawValue = getRawHeaders().get(HEARTBEAT); + String rawValue = getHeaderValue(HEARTBEAT); if (!StringUtils.hasText(rawValue)) { return null; } @@ -131,102 +257,82 @@ public class StompHeaders extends PubSubHeaders { public void setContentType(MediaType mediaType) { if (mediaType != null) { super.setContentType(mediaType); - getRawHeaders().put(CONTENT_TYPE, mediaType.toString()); + this.headers.put(CONTENT_TYPE, mediaType.toString()); } } + public MediaType getContentType() { + String value = getHeaderValue(CONTENT_TYPE); + return (value != null) ? MediaType.parseMediaType(value) : null; + } + public Integer getContentLength() { - String contentLength = getRawHeaders().get(CONTENT_LENGTH); + String contentLength = getHeaderValue(CONTENT_LENGTH); return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null; } public void setContentLength(int contentLength) { - getRawHeaders().put(CONTENT_LENGTH, String.valueOf(contentLength)); + this.headers.put(CONTENT_LENGTH, String.valueOf(contentLength)); } - @Override - public String getSubscriptionId() { - return StompCommand.SUBSCRIBE.equals(getStompCommand()) ? getRawHeaders().get(ID) : null; + public void setHeartbeat(long cx, long cy) { + this.headers.put(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy})); } - @Override - public void setSubscriptionId(String subscriptionId) { - Assert.isTrue(StompCommand.MESSAGE.equals(getStompCommand()), - "\"subscription\" can only be set on a STOMP MESSAGE frame"); - super.setSubscriptionId(subscriptionId); - getRawHeaders().put(SUBSCRIPTION, subscriptionId); + public void setAck(String ack) { + this.headers.put(ACK, ack); } - public void setHeartbeat(long cx, long cy) { - getRawHeaders().put(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy})); + public String getAck() { + return getHeaderValue(ACK); + } + + public void setNack(String nack) { + this.headers.put(NACK, nack); + } + + public String getNack() { + return getHeaderValue(NACK); + } + + public void setReceiptId(String receiptId) { + this.headers.put(RECEIPT_ID, receiptId); + } + + public String getReceiptId() { + return getHeaderValue(RECEIPT_ID); } public String getMessage() { - return getRawHeaders().get(MESSAGE); + return getHeaderValue(MESSAGE); } public void setMessage(String content) { - getRawHeaders().put(MESSAGE, content); + this.headers.put(MESSAGE, content); } public String getMessageId() { - return getRawHeaders().get(MESSAGE_ID); + return getHeaderValue(MESSAGE_ID); } public void setMessageId(String id) { - getRawHeaders().put(MESSAGE_ID, id); + this.headers.put(MESSAGE_ID, id); } public String getVersion() { - return getRawHeaders().get(VERSION); + return getHeaderValue(VERSION); } public void setVersion(String version) { - getRawHeaders().put(VERSION, version); - } - - - /** - * Update generic message headers from raw headers. This method only needs to be - * invoked when raw headers are added via {@link #getRawHeaders()}. - */ - public void updateMessageHeaders() { - String destination = getRawHeaders().get(DESTINATION); - if (destination != null) { - setDestination(destination); - } - String contentType = getRawHeaders().get(CONTENT_TYPE); - if (contentType != null) { - setContentType(MediaType.parseMediaType(contentType)); - } - if (StompCommand.SUBSCRIBE.equals(getStompCommand())) { - if (getRawHeaders().get(ID) != null) { - super.setSubscriptionId(getRawHeaders().get(ID)); - } - } + this.headers.put(VERSION, version); } - /** - * Update raw headers from generic message headers. This method only needs to be - * invoked if creating {@link StompHeaders} from {@link MessageHeaders} that never - * contained raw headers. - */ - public void updateRawHeaders() { - String destination = getDestination(); - if (destination != null) { - getRawHeaders().put(DESTINATION, destination); - } - MediaType contentType = getContentType(); - if (contentType != null) { - getRawHeaders().put(CONTENT_TYPE, contentType.toString()); - } - String subscriptionId = getSubscriptionId(); - if (subscriptionId != null) { - getRawHeaders().put(SUBSCRIPTION, subscriptionId); - } - if (StompCommand.MESSAGE.equals(getStompCommand()) && (getMessageId() == null)) { - getRawHeaders().put(MESSAGE_ID, getMessageHeaders().get(ID).toString()); - } + @Override + public String toString() { + return "StompHeaders [" + "messageType=" + getMessageType() + ", protocolMessageType=" + + getProtocolMessageType() + ", destination=" + getDestination() + + ", subscriptionId=" + getSubscriptionId() + ", sessionId=" + getSessionId() + + ", externalSourceHeaders=" + getExternalSourceHeaders() + ", headers=" + this.headers + "]"; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java index d8e66221446..3a8eb45f26c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/AbstractStompWebSocketHandler.java @@ -21,8 +21,8 @@ import java.util.concurrent.ConcurrentHashMap; import org.springframework.messaging.GenericMessage; import org.springframework.messaging.Message; -import org.springframework.web.messaging.stomp.StompHeaders; import org.springframework.web.messaging.stomp.StompCommand; +import org.springframework.web.messaging.stomp.StompHeaders; import org.springframework.web.messaging.stomp.support.StompMessageConverter; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; @@ -76,10 +76,10 @@ public abstract class AbstractStompWebSocketHandler extends TextWebSocketHandler protected void sendErrorMessage(WebSocketSession session, Throwable error) { - StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR); + StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR); stompHeaders.setMessage(error.getMessage()); - Message errorMessage = new GenericMessage(new byte[0], stompHeaders.getMessageHeaders()); + Message errorMessage = new GenericMessage(new byte[0], stompHeaders.toMessageHeaders()); byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage); try { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java index a21cf4d0744..fe4a082eb19 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/socket/DefaultStompWebSocketHandler.java @@ -63,10 +63,8 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler @Override public void accept(Message message) { - StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false); - if (stompHeaders.getProtocolMessageType() == null) { - stompHeaders.setProtocolMessageType(StompCommand.MESSAGE); - } + StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); + stompHeaders.setStompCommandIfNotSet(StompCommand.MESSAGE); if (StompCommand.CONNECTED.equals(stompHeaders.getStompCommand())) { // Ignore for now since we already sent it @@ -74,6 +72,9 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler } String sessionId = stompHeaders.getSessionId(); + if (sessionId == null) { + logger.error("Cannot send message without a session id: " + message); + } WebSocketSession session = getWebSocketSession(sessionId); byte[] payload; @@ -87,7 +88,7 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler } try { - Map messageHeaders = stompHeaders.getMessageHeaders(); + Map messageHeaders = stompHeaders.toMessageHeaders(); Message byteMessage = new GenericMessage(payload, messageHeaders); byte[] bytes = getStompMessageConverter().fromMessage(byteMessage); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); @@ -116,11 +117,11 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler public void handleStompMessage(final WebSocketSession session, Message message) { if (logger.isTraceEnabled()) { - logger.trace("Processing: " + message); + logger.trace("Processing STOMP message: " + message); } try { - StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), true); + StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); MessageType messageType = stompHeaders.getMessageType(); if (MessageType.CONNECT.equals(messageType)) { handleConnect(session, message); @@ -147,8 +148,8 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler protected void handleConnect(final WebSocketSession session, Message message) throws IOException { - StompHeaders connectStompHeaders = new StompHeaders(message.getHeaders(), true); - StompHeaders connectedStompHeaders = new StompHeaders(StompCommand.CONNECTED); + StompHeaders connectStompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); + StompHeaders connectedStompHeaders = StompHeaders.create(StompCommand.CONNECTED); Set acceptVersions = connectStompHeaders.getAcceptVersion(); if (acceptVersions.contains("1.2")) { @@ -167,7 +168,7 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler // TODO: security - Message connectedMessage = new GenericMessage(new byte[0], connectedStompHeaders.getMessageHeaders()); + Message connectedMessage = new GenericMessage(new byte[0], connectedStompHeaders.toMessageHeaders()); byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java index bbe024514ea..d0c256a3ad3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/RelayStompService.java @@ -93,21 +93,20 @@ public class RelayStompService extends AbstractMessageService { private void forwardMessage(Message message, StompCommand command) { - StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false); + StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); String sessionId = stompHeaders.getSessionId(); RelaySession session = RelayStompService.this.relaySessions.get(sessionId); Assert.notNull(session, "RelaySession not found"); try { - if (stompHeaders.getProtocolMessageType() == null) { - stompHeaders.setProtocolMessageType(StompCommand.SEND); - } + stompHeaders.setStompCommandIfNotSet(StompCommand.SEND); + MediaType contentType = stompHeaders.getContentType(); byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType); - Message byteMessage = new GenericMessage(payload, stompHeaders.getMessageHeaders()); + Message byteMessage = new GenericMessage(payload, stompHeaders.toMessageHeaders()); if (logger.isTraceEnabled()) { - logger.trace("Forwarding: " + byteMessage); + logger.trace("Forwarding STOMP " + stompHeaders.getStompCommand() + " message"); } byte[] bytes = this.stompMessageConverter.fromMessage(byteMessage); @@ -115,7 +114,7 @@ public class RelayStompService extends AbstractMessageService { session.getOutputStream().flush(); } catch (Exception ex) { - logger.error("Couldn't forward message", ex); + logger.error("Couldn't forward message " + message, ex); clearRelaySession(sessionId); } } @@ -233,13 +232,15 @@ public class RelayStompService extends AbstractMessageService { catch (IOException e) { logger.error("Socket error: " + e.getMessage()); clearRelaySession(sessionId); + sendErrorMessage("Lost connection"); } } private void sendErrorMessage(String message) { - StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR); + StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR); stompHeaders.setMessage(message); - Message errorMessage = new GenericMessage(new byte[0], stompHeaders.getMessageHeaders()); + stompHeaders.setSessionId(this.sessionId); + Message errorMessage = new GenericMessage(new byte[0], stompHeaders.toMessageHeaders()); getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, errorMessage); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java index a8453307ce6..0c7376ddb40 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java @@ -18,6 +18,7 @@ package org.springframework.web.messaging.stomp.support; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.Charset; +import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -25,6 +26,8 @@ import org.springframework.messaging.GenericMessage; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.messaging.stomp.StompCommand; import org.springframework.web.messaging.stomp.StompConversionException; import org.springframework.web.messaging.stomp.StompHeaders; @@ -80,15 +83,13 @@ public class StompMessageConverter { StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim()); Assert.notNull(command, "No command found"); - StompHeaders stompHeaders = new StompHeaders(command); - stompHeaders.setSessionId(sessionId); - + MultiValueMap headers = new LinkedMultiValueMap(); while (parser.hasNext()) { String header = parser.nextToken(COLON); if (header != null) { if (parser.hasNext()) { String value = parser.nextToken(LF); - stompHeaders.getRawHeaders().put(header, value); + headers.add(header, value); } else { throw new StompConversionException("Parse exception for " + headerContent); @@ -96,12 +97,13 @@ public class StompMessageConverter { } } + StompHeaders stompHeaders = StompHeaders.fromParsedFrame(command, headers); + stompHeaders.setSessionId(sessionId); + byte[] payload = new byte[totalLength - payloadIndex]; System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex); - stompHeaders.updateMessageHeaders(); - - return createMessage(command, stompHeaders.getMessageHeaders(), payload); + return createMessage(command, stompHeaders.toMessageHeaders(), payload); } private int findIndexOfPayload(byte[] bytes) { @@ -138,20 +140,20 @@ public class StompMessageConverter { public byte[] fromMessage(Message message) { ByteArrayOutputStream out = new ByteArrayOutputStream(); MessageHeaders messageHeaders = message.getHeaders(); - StompHeaders stompHeaders = new StompHeaders(messageHeaders, false); - stompHeaders.updateRawHeaders(); + StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders); try { out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8")); out.write(LF); - for (Entry entry : stompHeaders.getRawHeaders().entrySet()) { + for (Entry> entry : stompHeaders.toStompMessageHeaders().entrySet()) { String key = entry.getKey(); key = replaceAllOutbound(key); - String value = entry.getValue(); - out.write(key.getBytes("UTF-8")); - out.write(COLON); - value = replaceAllOutbound(value); - out.write(value.getBytes("UTF-8")); - out.write(LF); + for (String value : entry.getValue()) { + out.write(key.getBytes("UTF-8")); + out.write(COLON); + value = replaceAllOutbound(value); + out.write(value.getBytes("UTF-8")); + out.write(LF); + } } out.write(LF); out.write(message.getPayload()); diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java index f359100c954..e2eaed097f5 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java @@ -22,8 +22,8 @@ import org.junit.Test; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.web.messaging.MessageType; -import org.springframework.web.messaging.stomp.StompHeaders; import org.springframework.web.messaging.stomp.StompCommand; +import org.springframework.web.messaging.stomp.StompHeaders; import static org.junit.Assert.*; @@ -52,15 +52,17 @@ public class StompMessageConverterTests { assertEquals(0, message.getPayload().length); MessageHeaders messageHeaders = message.getHeaders(); - StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); - assertEquals(6, stompHeaders.getMessageHeaders().size()); + StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders); + assertEquals(7, stompHeaders.toMessageHeaders().size()); + + assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); + assertEquals("github.org", stompHeaders.getHost()); + assertEquals(MessageType.CONNECT, stompHeaders.getMessageType()); assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand()); assertEquals("session-123", stompHeaders.getSessionId()); assertNotNull(messageHeaders.get(MessageHeaders.ID)); assertNotNull(messageHeaders.get(MessageHeaders.TIMESTAMP)); - assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); - assertEquals("github.org", stompHeaders.getRawHeaders().get("host")); String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); @@ -80,9 +82,9 @@ public class StompMessageConverterTests { assertEquals(0, message.getPayload().length); MessageHeaders messageHeaders = message.getHeaders(); - StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); + StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders); assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); - assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt")); + assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getExternalSourceHeaders().get("ho:\ns\rt").get(0)); String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); @@ -102,9 +104,9 @@ public class StompMessageConverterTests { assertEquals(0, message.getPayload().length); MessageHeaders messageHeaders = message.getHeaders(); - StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); + StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders); assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion()); - assertEquals("github.org", stompHeaders.getRawHeaders().get("host")); + assertEquals("github.org", stompHeaders.getHost()); String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); @@ -124,9 +126,9 @@ public class StompMessageConverterTests { assertEquals(0, message.getPayload().length); MessageHeaders messageHeaders = message.getHeaders(); - StompHeaders stompHeaders = new StompHeaders(messageHeaders, true); + StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders); assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); - assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt")); + assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getExternalSourceHeaders().get("ho:\ns\rt").get(0)); String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");