diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java
index e2a15998adb..2aebcabc3ea 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java
@@ -22,143 +22,220 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.springframework.http.MediaType;
+import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
+import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
+import org.springframework.util.LinkedMultiValueMap;
/**
+ * A base class for working with message headers in Web, messaging protocols that support
+ * the publish-subscribe message pattern. Provides uniform access to specific values
+ * common across protocols such as a destination, message type (publish,
+ * subscribe/unsubscribe), session id, and others.
+ *
+ * This class can be used to prepare headers for a new pub-sub message, or to access
+ * and/or modify headers of an existing message.
+ *
+ * Use one of the static factory method in this class, then call getters and setters, and
+ * at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class PubSubHeaders {
+ protected Log logger = LogFactory.getLog(getClass());
+
private static final String DESTINATIONS = "destinations";
private static final String CONTENT_TYPE = "contentType";
private static final String MESSAGE_TYPE = "messageType";
- private static final String SUBSCRIPTION_ID = "subscriptionId";
-
private static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
private static final String SESSION_ID = "sessionId";
- private static final String RAW_HEADERS = "rawHeaders";
+ private static final String SUBSCRIPTION_ID = "subscriptionId";
+
+ private static final String EXTERNAL_SOURCE_HEADERS = "extSourceHeaders";
- private final Map messageHeaders;
+ private static final Map> emptyMultiValueMap =
+ Collections.unmodifiableMap(new LinkedMultiValueMap(0));
+
+
+ // wrapped read-only message headers
+ private final MessageHeaders originalHeaders;
+
+ // header updates
+ private final Map headers = new HashMap(4);
+
+ // saved headers from a message from a remote source
+ private final Map> externalSourceHeaders;
- private final Map rawHeaders;
/**
- * Constructor for building new headers.
- *
- * @param messageType the message type
- * @param protocolMessageType the protocol-specific message type or command
+ * A constructor for creating new message headers.
+ * This constructor is protected. See factory methods in this and sub-classes.
*/
- public PubSubHeaders(MessageType messageType, Object protocolMessageType) {
+ protected PubSubHeaders(MessageType messageType, Object protocolMessageType,
+ Map> externalSourceHeaders) {
+
+ this.originalHeaders = null;
+
+ Assert.notNull(messageType, "messageType is required");
+ this.headers.put(MESSAGE_TYPE, messageType);
- this.messageHeaders = new HashMap();
- this.messageHeaders.put(MESSAGE_TYPE, messageType);
if (protocolMessageType != null) {
- this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
+ this.headers.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
}
- this.rawHeaders = new HashMap();
- this.messageHeaders.put(RAW_HEADERS, this.rawHeaders);
- }
-
- public PubSubHeaders() {
- this(MessageType.MESSAGE, null);
+ if (externalSourceHeaders == null) {
+ this.externalSourceHeaders = emptyMultiValueMap;
+ }
+ else {
+ this.externalSourceHeaders = Collections.unmodifiableMap(externalSourceHeaders); // TODO: list values must also be read-only
+ this.headers.put(EXTERNAL_SOURCE_HEADERS, this.externalSourceHeaders);
+ }
}
/**
- * Constructor for access to existing {@link MessageHeaders}.
- *
- * @param messageHeaders
+ * A constructor for accessing and modifying existing message headers. This
+ * constructor is protected. See factory methods in this and sub-classes.
*/
@SuppressWarnings("unchecked")
- public PubSubHeaders(MessageHeaders messageHeaders, boolean readOnly) {
+ protected PubSubHeaders(MessageHeaders originalHeaders) {
+ this.originalHeaders = originalHeaders;
+ this.externalSourceHeaders = (originalHeaders.get(EXTERNAL_SOURCE_HEADERS) != null) ?
+ (Map>) originalHeaders.get(EXTERNAL_SOURCE_HEADERS) : emptyMultiValueMap;
+ }
- this.messageHeaders = readOnly ? messageHeaders : new HashMap(messageHeaders);
- this.rawHeaders = this.messageHeaders.containsKey(RAW_HEADERS) ?
- (Map) messageHeaders.get(RAW_HEADERS) : Collections.emptyMap();
- if (this.messageHeaders.get(MESSAGE_TYPE) == null) {
- this.messageHeaders.put(MESSAGE_TYPE, MessageType.MESSAGE);
- }
+ /**
+ * Create {@link PubSubHeaders} for a new {@link Message}.
+ */
+ public static PubSubHeaders create() {
+ return new PubSubHeaders(MessageType.MESSAGE, null, null);
+ }
+
+ /**
+ * Create {@link PubSubHeaders} from existing message headers.
+ */
+ public static PubSubHeaders fromMessageHeaders(MessageHeaders originalHeaders) {
+ return new PubSubHeaders(originalHeaders);
}
- public Map getMessageHeaders() {
- return this.messageHeaders;
+ /**
+ * Return the original, wrapped headers (i.e. unmodified) or a new Map including any
+ * updates made via setters.
+ */
+ public Map toMessageHeaders() {
+ if (!isModified()) {
+ return this.originalHeaders;
+ }
+ Map result = new HashMap();
+ if (this.originalHeaders != null) {
+ result.putAll(this.originalHeaders);
+ }
+ result.putAll(this.headers);
+ return result;
}
- public Map getRawHeaders() {
- return this.rawHeaders;
+ public boolean isModified() {
+ return ((this.originalHeaders == null) || !this.headers.isEmpty());
}
public MessageType getMessageType() {
- return (MessageType) this.messageHeaders.get(MESSAGE_TYPE);
+ return (MessageType) getHeaderValue(MESSAGE_TYPE);
+ }
+
+ private Object getHeaderValue(String headerName) {
+ if (this.headers.get(headerName) != null) {
+ return this.headers.get(headerName);
+ }
+ else if (this.originalHeaders.get(headerName) != null) {
+ return this.originalHeaders.get(headerName);
+ }
+ return null;
}
- public void setProtocolMessageType(Object protocolMessageType) {
- this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
+ protected void setProtocolMessageType(Object protocolMessageType) {
+ this.headers.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
}
- public Object getProtocolMessageType() {
- return this.messageHeaders.get(PROTOCOL_MESSAGE_TYPE);
+ protected Object getProtocolMessageType() {
+ return getHeaderValue(PROTOCOL_MESSAGE_TYPE);
}
public void setDestination(String destination) {
- this.messageHeaders.put(DESTINATIONS, Arrays.asList(destination));
+ Assert.notNull(destination, "destination is required");
+ this.headers.put(DESTINATIONS, Arrays.asList(destination));
}
+ @SuppressWarnings("unchecked")
public String getDestination() {
- @SuppressWarnings("unchecked")
- List destination = (List) messageHeaders.get(DESTINATIONS);
- return CollectionUtils.isEmpty(destination) ? null : destination.get(0);
+ List destinations = (List) getHeaderValue(DESTINATIONS);
+ return CollectionUtils.isEmpty(destinations) ? null : destinations.get(0);
}
@SuppressWarnings("unchecked")
public List getDestinations() {
- return (List) messageHeaders.get(DESTINATIONS);
+ List destinations = (List) getHeaderValue(DESTINATIONS);
+ return CollectionUtils.isEmpty(destinations) ? null : destinations;
}
public void setDestinations(List destinations) {
- if (destinations != null) {
- this.messageHeaders.put(DESTINATIONS, destinations);
- }
+ Assert.notNull(destinations, "destinations are required");
+ this.headers.put(DESTINATIONS, destinations);
}
public MediaType getContentType() {
- return (MediaType) this.messageHeaders.get(CONTENT_TYPE);
+ return (MediaType) getHeaderValue(CONTENT_TYPE);
}
- public void setContentType(MediaType mediaType) {
- if (mediaType != null) {
- this.messageHeaders.put(CONTENT_TYPE, mediaType);
- }
+ public void setContentType(MediaType contentType) {
+ Assert.notNull(contentType, "contentType is required");
+ this.headers.put(CONTENT_TYPE, contentType);
}
public String getSubscriptionId() {
- return (String) this.messageHeaders.get(SUBSCRIPTION_ID);
+ return (String) getHeaderValue(SUBSCRIPTION_ID);
}
public void setSubscriptionId(String subscriptionId) {
- this.messageHeaders.put(SUBSCRIPTION_ID, subscriptionId);
+ this.headers.put(SUBSCRIPTION_ID, subscriptionId);
}
public String getSessionId() {
- return (String) this.messageHeaders.get(SESSION_ID);
+ return (String) getHeaderValue(SESSION_ID);
}
public void setSessionId(String sessionId) {
- this.messageHeaders.put(SESSION_ID, sessionId);
+ this.headers.put(SESSION_ID, sessionId);
+ }
+
+ /**
+ * Return a read-only map of headers originating from a message received by the
+ * application from an external source (e.g. from a remote WebSocket endpoint). The
+ * header names and values are exactly as they were, and are protocol specific but may
+ * also be custom application headers if the protocol allows that.
+ */
+ public Map> getExternalSourceHeaders() {
+ return this.externalSourceHeaders;
+ }
+
+ @Override
+ public String toString() {
+ return "PubSubHeaders [originalHeaders=" + this.originalHeaders + ", headers="
+ + this.headers + ", externalSourceHeaders=" + this.externalSourceHeaders + "]";
}
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
index c3a4a880472..89816b75b11 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractMessageService.java
@@ -72,10 +72,10 @@ public abstract class AbstractMessageService {
}
if (logger.isTraceEnabled()) {
- logger.trace("Processing notification: " + message);
+ logger.trace("Processing message id=" + message.getHeaders().getId());
}
- PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
+ PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
MessageType messageType = headers.getMessageType();
if (messageType == null || messageType.equals(MessageType.OTHER)) {
processOther(message);
@@ -129,7 +129,7 @@ public abstract class AbstractMessageService {
}
private boolean isAllowedDestination(Message> message) {
- PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
+ PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String destination = headers.getDestination();
if (destination == null) {
return true;
@@ -138,7 +138,7 @@ public abstract class AbstractMessageService {
for (String pattern : this.disallowedDestinations) {
if (this.pathMatcher.match(pattern, destination)) {
if (logger.isTraceEnabled()) {
- logger.trace("Skip notification: " + message);
+ logger.trace("Skip notification message id=" + message.getHeaders().getId());
}
return false;
}
@@ -151,7 +151,7 @@ public abstract class AbstractMessageService {
}
}
if (logger.isTraceEnabled()) {
- logger.trace("Skip notification: " + message);
+ logger.trace("Skip notification message id=" + message.getHeaders().getId());
}
return false;
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
index 963a978368b..b81cdc03a7c 100644
--- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
+++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/PubSubMessageService.java
@@ -62,7 +62,7 @@ public class PubSubMessageService extends AbstractMessageService {
try {
// Convert to byte[] payload before the fan-out
- PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
+ PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType());
message = new GenericMessage(payload, message.getHeaders());
@@ -82,19 +82,19 @@ public class PubSubMessageService extends AbstractMessageService {
if (logger.isDebugEnabled()) {
logger.debug("Subscribe " + message);
}
- PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
+ PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
final String subscriptionId = headers.getSubscriptionId();
EventRegistration registration = getEventBus().registerConsumer(getPublishKey(headers.getDestination()),
new EventConsumer>() {
@Override
public void accept(Message> message) {
- PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
- PubSubHeaders outHeaders = new PubSubHeaders();
+ PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
+ PubSubHeaders outHeaders = PubSubHeaders.create();
outHeaders.setDestinations(inHeaders.getDestinations());
outHeaders.setContentType(inHeaders.getContentType());
outHeaders.setSubscriptionId(subscriptionId);
Object payload = message.getPayload();
- message = new GenericMessage