diff --git a/spring-messaging/src/main/java/org/springframework/messaging/converter/AbstractMessageConverter.java b/spring-messaging/src/main/java/org/springframework/messaging/converter/AbstractMessageConverter.java index d03d70cdcdd..4267a6abb8b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/converter/AbstractMessageConverter.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/converter/AbstractMessageConverter.java @@ -27,6 +27,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; import org.springframework.util.MimeType; @@ -194,18 +195,27 @@ public abstract class AbstractMessageConverter implements MessageConverter { @Override public final Message toMessage(Object payload, MessageHeaders headers) { + if (!canConvertTo(payload, headers)) { return null; } + payload = convertToInternal(payload, headers); + MimeType mimeType = getDefaultContentType(payload); + + if (headers != null) { + MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(headers, MessageHeaderAccessor.class); + if (accessor != null && accessor.isMutable()) { + accessor.setHeaderIfAbsent(MessageHeaders.CONTENT_TYPE, mimeType); + return MessageBuilder.createMessage(payload, accessor.getMessageHeaders()); + } + } + MessageBuilder builder = MessageBuilder.withPayload(payload); if (headers != null) { builder.copyHeaders(headers); } - MimeType mimeType = getDefaultContentType(payload); - if (mimeType != null) { - builder.setHeaderIfAbsent(MessageHeaders.CONTENT_TYPE, mimeType); - } + builder.setHeaderIfAbsent(MessageHeaders.CONTENT_TYPE, mimeType); return builder.build(); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/converter/SimpleMessageConverter.java b/spring-messaging/src/main/java/org/springframework/messaging/converter/SimpleMessageConverter.java index 28a4f373571..59cbf65aeaa 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/converter/SimpleMessageConverter.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/converter/SimpleMessageConverter.java @@ -19,6 +19,7 @@ package org.springframework.messaging.converter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.ClassUtils; /** @@ -44,7 +45,16 @@ public class SimpleMessageConverter implements MessageConverter { @Override public Message toMessage(Object payload, MessageHeaders headers) { - return (payload != null ? MessageBuilder.withPayload(payload).copyHeaders(headers).build() : null); + if (payload == null) { + return null; + } + if (headers != null) { + MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(headers, MessageHeaderAccessor.class); + if (accessor != null && accessor.isMutable()) { + return MessageBuilder.createMessage(payload, accessor.getMessageHeaders()); + } + } + return MessageBuilder.withPayload(payload).copyHeaders(headers).build(); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/core/AbstractMessageSendingTemplate.java b/spring-messaging/src/main/java/org/springframework/messaging/core/AbstractMessageSendingTemplate.java index a2c82570b03..cc3bc460ad4 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/core/AbstractMessageSendingTemplate.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/core/AbstractMessageSendingTemplate.java @@ -26,7 +26,6 @@ import org.springframework.messaging.MessagingException; import org.springframework.messaging.converter.MessageConversionException; import org.springframework.messaging.converter.MessageConverter; import org.springframework.messaging.converter.SimpleMessageConverter; -import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; /** @@ -109,8 +108,7 @@ public abstract class AbstractMessageSendingTemplate implements MessageSendin @Override public void convertAndSend(D destination, Object payload, Map headers) throws MessagingException { - MessagePostProcessor postProcessor = null; - this.convertAndSend(destination, payload, headers, postProcessor); + this.convertAndSend(destination, payload, headers, null); } @Override @@ -122,45 +120,49 @@ public abstract class AbstractMessageSendingTemplate implements MessageSendin public void convertAndSend(D destination, Object payload, MessagePostProcessor postProcessor) throws MessagingException { - Map headers = null; - this.convertAndSend(destination, payload, headers, postProcessor); + this.convertAndSend(destination, payload, null, postProcessor); } @Override public void convertAndSend(D destination, Object payload, Map headers, MessagePostProcessor postProcessor) throws MessagingException { - headers = processHeadersToSend(headers); - - MessageHeaders messageHeaders; - if (headers != null && headers instanceof MessageHeaders) { - MessageHeaderAccessor.getAccessor() + MessageHeaders messageHeaders = null; + headers = processHeadersToSend(headers); + if (headers != null) { + if (headers instanceof MessageHeaders) { + messageHeaders = (MessageHeaders) headers; + } + else { + messageHeaders = new MessageHeaders(headers); + } } - MessageHeaders messageHeaders = (headers != null) ? new MessageHeaders(headers) : null; Message message = this.converter.toMessage(payload, messageHeaders); if (message == null) { String payloadType = (payload != null) ? payload.getClass().getName() : null; + Object contentType = (messageHeaders != null) ? messageHeaders.get(MessageHeaders.CONTENT_TYPE) : null; throw new MessageConversionException("Unable to convert payload type '" - + payloadType + "', Content-Type=" + messageHeaders.get(MessageHeaders.CONTENT_TYPE) - + ", converter=" + this.converter, null); + + payloadType + "', Content-Type=" + contentType + ", converter=" + this.converter, null); } if (postProcessor != null) { message = postProcessor.postProcessMessage(message); } + this.send(destination, message); } /** - * Provides access to the map of headers before a send operation. - * Implementations can modify the headers by returning a different map. - * This implementation returns the map that was passed in (i.e. without any changes). + * Provides access to the map of input headers before a send operation. Sub-classes + * can modify the headers and then return the same or a different map. + * + *

This default implementation in this class returns the input map. * - * @param headers the headers to send, possibly {@code null} - * @return the actual headers to send + * @param headers the headers to send or {@code null}. + * @return the actual headers to send or {@code null}. */ protected Map processHeadersToSend(Map headers) { return headers; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/core/GenericMessagingTemplate.java b/spring-messaging/src/main/java/org/springframework/messaging/core/GenericMessagingTemplate.java index 97dd813ee68..059534ac88b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/core/GenericMessagingTemplate.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/core/GenericMessagingTemplate.java @@ -31,6 +31,7 @@ import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.PollableChannel; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; /** @@ -110,6 +111,11 @@ public class GenericMessagingTemplate extends AbstractDestinationResolvingMessag Assert.notNull(channel, "channel must not be null"); + MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class); + if (accessor != null && accessor.isMutable()) { + accessor.setImmutable(); + } + long timeout = this.sendTimeout; boolean sent = (timeout >= 0) ? channel.send(message, timeout) : channel.send(message); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java index 62912cf7df3..da9f2ce76ca 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java @@ -322,6 +322,7 @@ public abstract class AbstractMethodMessageHandler @Override public void handleMessage(Message message) throws MessagingException { + String destination = getDestination(message); if (destination == null) { logger.trace("Ignoring message, no destination"); @@ -342,9 +343,11 @@ public abstract class AbstractMethodMessageHandler MessageHeaderAccessor headerAccessor = MessageHeaderAccessor.getMutableAccessor(message); headerAccessor.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, lookupDestination); + headerAccessor.setLeaveMutable(true); message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders()); handleMessageInternal(message, lookupDestination); + headerAccessor.setImmutable(); } protected abstract String getDestination(Message message); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageSendingOperations.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageSendingOperations.java index 156bbc7d5d2..70d5470a79f 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageSendingOperations.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageSendingOperations.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,10 @@ import org.springframework.messaging.core.MessageSendingOperations; /** * A specialization of {@link MessageSendingOperations} with methods for use with - * the Spring Framework support for simple messaging protocols (like STOMP). + * the Spring Framework support for Simple Messaging Protocols (like STOMP). + * + *

For more on user destinations see + * {@link org.springframework.messaging.simp.user.UserDestinationResolver}. * * @author Rossen Stoyanchev * @since 4.0 @@ -32,7 +35,8 @@ import org.springframework.messaging.core.MessageSendingOperations; public interface SimpMessageSendingOperations extends MessageSendingOperations { /** - * Send a message to a specific user. + * Send a message to the given user. + * * @param user the user that should receive the message. * @param destination the destination to send the message to. * @param payload the payload to send @@ -40,27 +44,62 @@ public interface SimpMessageSendingOperations extends MessageSendingOperationsBy default headers are interpreted as native headers (e.g. STOMP) and + * are saved under a special key in the resulting Spring + * {@link org.springframework.messaging.Message Message}. In effect when the + * message leaves the application, the provided headers are included with it + * and delivered to the destination (e.g. the STOMP client or broker). + * + *

If the map already contains the key + * {@link org.springframework.messaging.support.NativeMessageHeaderAccessor#NATIVE_HEADERS "nativeHeaders"} + * or was prepared with + * {@link org.springframework.messaging.simp.SimpMessageHeaderAccessor SimpMessageHeaderAccessor} + * then the headers are used directly. A common expected case is providing a + * content type (to influence the message conversion) and native headers. + * This may be done as follows: + * + *

+	 * SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create();
+	 * accessor.setContentType(MimeTypeUtils.TEXT_PLAIN);
+	 * accessor.setNativeHeader("foo", "bar");
+	 * accessor.setLeaveMutable(true);
+	 * MessageHeaders headers = accessor.getMessageHeaders();
+	 *
+	 * messagingTemplate.convertAndSendToUser(user, destination, payload, headers);
+	 * 
+ * + *

Note: if the {@code MessageHeaders} are mutable as in + * the above example, implementations of this interface should take notice and + * update the headers in the same instance (rather than copy or re-create it) + * and then set it immutable before sending the final message. + * + * @param user the user that should receive the message, must not be {@code null} + * @param destination the destination to send the message to, must not be {@code null} + * @param payload the payload to send, may be {@code null} + * @param headers the message headers, may be {@code null} */ void convertAndSendToUser(String user, String destination, Object payload, Map headers) throws MessagingException; /** - * Send a message to a specific user. - * @param user the user that should receive the message. - * @param destination the destination to send the message to. - * @param payload the payload to send + * Send a message to the given user. + * + * @param user the user that should receive the message, must not be {@code null} + * @param destination the destination to send the message to, must not be {@code null} + * @param payload the payload to send, may be {@code null} * @param postProcessor a postProcessor to post-process or modify the created message */ void convertAndSendToUser(String user, String destination, Object payload, MessagePostProcessor postProcessor) throws MessagingException; /** - * Send a message to a specific user. + * Send a message to the given user. + * + *

See {@link #convertAndSend(Object, Object, java.util.Map)} for important + * notes regarding the input headers. + * * @param user the user that should receive the message. * @param destination the destination to send the message to. * @param payload the payload to send diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java index dd8a9028568..a0e720eddad 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java @@ -16,29 +16,25 @@ package org.springframework.messaging.simp; -import java.util.HashMap; import java.util.Map; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageDeliveryException; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessagingException; import org.springframework.messaging.core.AbstractMessageSendingTemplate; import org.springframework.messaging.core.MessagePostProcessor; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.support.NativeMessageHeaderAccessor; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; /** - * A specialization of {@link AbstractMessageSendingTemplate} that interprets a - * String-based destination as the - * {@link org.springframework.messaging.simp.SimpMessageHeaderAccessor#DESTINATION_HEADER DESTINATION_HEADER} - * to be added to the headers of sent messages. - *

- * Also provides methods for sending messages to a user. See + * An implementation of {@link org.springframework.messaging.simp.SimpMessageSendingOperations}. + * + *

Also provides methods for sending messages to a user. See * {@link org.springframework.messaging.simp.user.UserDestinationResolver UserDestinationResolver} * for more on user destinations. * @@ -106,21 +102,66 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplateIf a destination header is not already present ,the message is sent + * to the configured {@link #setDefaultDestination(Object) defaultDestination} + * or an exception an {@code IllegalStateException} is raised if that isn't + * configured. + * + * @param message the message to send, never {@code null} + */ @Override public void send(Message message) { + Assert.notNull(message, "'message' is required"); String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); - destination = (destination != null) ? destination : getRequiredDefaultDestination(); - doSend(destination, message); + if (destination != null) { + sendInternal(message); + return; + } + doSend(getRequiredDefaultDestination(), message); } + @SuppressWarnings("unchecked") @Override protected void doSend(String destination, Message message) { + Assert.notNull(destination, "Destination must not be null"); - SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); - headerAccessor.setDestination(destination); - headerAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); - message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders()); + SimpMessageHeaderAccessor simpAccessor = + MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class); + + if (simpAccessor != null) { + if (simpAccessor.isMutable()) { + simpAccessor.setDestination(destination); + simpAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); + simpAccessor.setImmutable(); + sendInternal(message); + return; + } + else { + // Try and keep the original accessor type + simpAccessor = (SimpMessageHeaderAccessor) MessageHeaderAccessor.getMutableAccessor(message); + } + } + else { + simpAccessor = SimpMessageHeaderAccessor.wrap(message); + } + + simpAccessor.setDestination(destination); + simpAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); + message = MessageBuilder.createMessage(message.getPayload(), simpAccessor.getMessageHeaders()); + sendInternal(message); + } + + private void sendInternal(Message message) { + + String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); + Assert.notNull(destination); long timeout = this.sendTimeout; boolean sent = (timeout >= 0) @@ -129,12 +170,11 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate - * If the given headers already contain the key - * {@link org.springframework.messaging.support.NativeMessageHeaderAccessor#NATIVE_HEADERS NATIVE_HEADERS} - * then the same header map is returned (i.e. without any changes). + * {@link org.springframework.messaging.support.NativeMessageHeaderAccessor#NATIVE_HEADERS NATIVE_HEADERS NATIVE_HEADERS NATIVE_HEADERS}. + * effectively treats the input header map as headers to be sent out to the + * destination. + * + *

However if the given headers already contain the key + * {@code NATIVE_HEADERS NATIVE_HEADERS} then the same headers instance is + * returned without changes. + * + *

Also if the given headers were prepared and obtained with + * {@link SimpMessageHeaderAccessor#getMessageHeaders()} then the same headers + * instance is also returned without changes. */ @Override protected Map processHeadersToSend(Map headers) { if (headers == null) { - return null; + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + headerAccessor.setLeaveMutable(true); + return headerAccessor.getMessageHeaders(); } - else if (headers.containsKey(NativeMessageHeaderAccessor.NATIVE_HEADERS)) { + + if (headers.containsKey(NativeMessageHeaderAccessor.NATIVE_HEADERS)) { return headers; } - else { - MultiValueMap nativeHeaders = new LinkedMultiValueMap(headers.size()); - for (String key : headers.keySet()) { - Object value = headers.get(key); - nativeHeaders.set(key, (value != null ? value.toString() : null)); + + if (headers instanceof MessageHeaders) { + SimpMessageHeaderAccessor accessor = + MessageHeaderAccessor.getAccessor((MessageHeaders) headers, SimpMessageHeaderAccessor.class); + if (accessor != null) { + return headers; } + } - headers = new HashMap(1); - headers.put(NativeMessageHeaderAccessor.NATIVE_HEADERS, nativeHeaders); - return headers; + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + for (String key : headers.keySet()) { + Object value = headers.get(key); + headerAccessor.setNativeHeader(key, (value != null ? value.toString() : null)); } + return headerAccessor.getMessageHeaders(); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java index 69c301182ce..1f481de8cc8 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java @@ -24,15 +24,14 @@ import org.springframework.core.annotation.AnnotationUtils; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHeaders; -import org.springframework.messaging.core.MessagePostProcessor; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageSendingOperations; +import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.user.DestinationUserNameProvider; -import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; @@ -123,21 +122,13 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH MessageHeaders headers = message.getHeaders(); String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); - MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(sessionId); SendToUser sendToUser = returnType.getMethodAnnotation(SendToUser.class); if (sendToUser != null) { - Principal principal = SimpMessageHeaderAccessor.getUser(headers); - if (principal == null) { - throw new MissingSessionUserException(message); - } - String userName = principal.getName(); - if (principal instanceof DestinationUserNameProvider) { - userName = ((DestinationUserNameProvider) principal).getDestinationUserName(); - } + String user = getUserName(message, headers); String[] destinations = getTargetDestinations(sendToUser, message, this.defaultUserDestinationPrefix); for (String destination : destinations) { - this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, postProcessor); + this.messagingTemplate.convertAndSendToUser(user, destination, returnValue, createHeaders(sessionId)); } return; } @@ -145,15 +136,26 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH SendTo sendTo = returnType.getMethodAnnotation(SendTo.class); String[] destinations = getTargetDestinations(sendTo, message, this.defaultDestinationPrefix); for (String destination : destinations) { - this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor); + this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId)); } } } - protected String[] getTargetDestinations(Annotation annot, Message message, String defaultPrefix) { + protected String getUserName(Message message, MessageHeaders headers) { + Principal principal = SimpMessageHeaderAccessor.getUser(headers); + if (principal == null) { + throw new MissingSessionUserException(message); + } + if (principal instanceof DestinationUserNameProvider) { + return ((DestinationUserNameProvider) principal).getDestinationUserName(); + } + return principal.getName(); + } + + protected String[] getTargetDestinations(Annotation annotation, Message message, String defaultPrefix) { - if (annot != null) { - String[] value = (String[]) AnnotationUtils.getValue(annot); + if (annotation != null) { + String[] value = (String[]) AnnotationUtils.getValue(annotation); if (!ObjectUtils.isEmpty(value)) { return value; } @@ -162,23 +164,14 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH return new String[] { defaultPrefix + message.getHeaders().get(name) }; } - - private final class SessionHeaderPostProcessor implements MessagePostProcessor { - - private final String sessionId; - - public SessionHeaderPostProcessor(String sessionId) { - this.sessionId = sessionId; - } - - @Override - public Message postProcessMessage(Message message) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - headers.setSessionId(this.sessionId); - return MessageBuilder.createMessage(message.getPayload(), headers.getMessageHeaders()); - } + private MessageHeaders createHeaders(String sessionId) { + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + headerAccessor.setSessionId(sessionId); + headerAccessor.setLeaveMutable(true); + return headerAccessor.getMessageHeaders(); } + @Override public String toString() { return "SendToMethodReturnValueHandler [annotationRequired=" + annotationRequired + "]"; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java index 17e82eae175..5d0851b6c01 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java @@ -58,7 +58,7 @@ import org.springframework.messaging.simp.SimpMessageSendingOperations; import org.springframework.messaging.simp.SimpMessageTypeMessageCondition; import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.annotation.SubscribeMapping; -import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.stereotype.Controller; import org.springframework.util.AntPathMatcher; import org.springframework.util.Assert; @@ -330,7 +330,7 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan @Override protected String getDestination(Message message) { - return (String) SimpMessageHeaderAccessor.getDestination(message.getHeaders()); + return SimpMessageHeaderAccessor.getDestination(message.getHeaders()); } @Override @@ -357,10 +357,9 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan Map vars = getPathMatcher().extractUriTemplateVariables(matchedPattern, lookupDestination); if (!CollectionUtils.isEmpty(vars)) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - headers.setHeader(DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER, vars); - message = MessageBuilder.createMessage(message.getPayload(), headers.getMessageHeaders()); - + MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class); + Assert.state(accessor != null && accessor.isMutable()); + accessor.setHeader(DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER, vars); } super.handleMatch(mapping, handlerMethod, lookupDestination, message); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java index 4d91326f8e1..f17797b0e32 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ package org.springframework.messaging.simp.annotation.support; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; -import org.springframework.messaging.core.MessagePostProcessor; import org.springframework.messaging.core.MessageSendingOperations; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; @@ -27,17 +26,18 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.annotation.SubscribeMapping; -import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; /** - * A {@link HandlerMethodReturnValueHandler} for replying directly to a subscription. It - * supports methods annotated with {@link org.springframework.messaging.simp.annotation.SubscribeMapping} unless they're also annotated - * with {@link SendTo} or {@link SendToUser}. + * A {@link HandlerMethodReturnValueHandler} for replying directly to a subscription. + * It is supported on methods annotated with + * {@link org.springframework.messaging.simp.annotation.SubscribeMapping} + * unless they're also annotated with {@link SendTo} or {@link SendToUser} in + * which case a message is sent to the broker instead. * - *

The value returned from the method is converted, and turned to a {@link Message} and - * then enriched with the sessionId, subscriptionId, and destination of the input message. - * The message is then sent directly back to the connected client. + *

The value returned from the method is converted, and turned to a {@link Message} + * and then enriched with the sessionId, subscriptionId, and destination of the + * input message. The message is then sent directly back to the connected client. * * @author Rossen Stoyanchev * @since 4.0 @@ -48,8 +48,10 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn /** - * @param messagingTemplate a messaging template for sending messages directly - * to clients, e.g. in response to a subscription + * Class constructor. + * + * @param messagingTemplate a messaging template to send messages to, most + * likely the "clientOutboundChannel", must not be {@link null}. */ public SubscriptionMethodReturnValueHandler(MessageSendingOperations messagingTemplate) { Assert.notNull(messagingTemplate, "messagingTemplate must not be null"); @@ -73,36 +75,22 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn } MessageHeaders headers = message.getHeaders(); + String destination = SimpMessageHeaderAccessor.getDestination(headers); String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers); - String destination = SimpMessageHeaderAccessor.getDestination(headers); - Assert.state(subscriptionId != null, "No subsriptiondId in input message to method " + returnType.getMethod()); + Assert.state(subscriptionId != null, + "No subscriptionId in message=" + message + ", method=" + returnType.getMethod()); - MessagePostProcessor postProcessor = new SubscriptionHeaderPostProcessor(sessionId, subscriptionId); - this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor); + this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, subscriptionId)); } - - private final class SubscriptionHeaderPostProcessor implements MessagePostProcessor { - - private final String sessionId; - - private final String subscriptionId; - - - public SubscriptionHeaderPostProcessor(String sessionId, String subscriptionId) { - this.sessionId = sessionId; - this.subscriptionId = subscriptionId; - } - - @Override - public Message postProcessMessage(Message message) { - SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); - headerAccessor.setSessionId(this.sessionId); - headerAccessor.setSubscriptionId(this.subscriptionId); - headerAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); - return MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders()); - } + private MessageHeaders createHeaders(String sessionId, String subscriptionId) { + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + headerAccessor.setSessionId(sessionId); + headerAccessor.setSubscriptionId(subscriptionId); + headerAccessor.setLeaveMutable(true); + return headerAccessor.getMessageHeaders(); } + } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/converter/MessageConverterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/converter/MessageConverterTests.java index 77e8f974f76..cc84bfacf63 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/converter/MessageConverterTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/converter/MessageConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,8 @@ import org.junit.Before; import org.junit.Test; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; @@ -34,7 +36,8 @@ import static org.junit.Assert.*; import static org.junit.Assert.assertEquals; /** - * Test fixture for {@link org.springframework.messaging.converter.AbstractMessageConverter}. + * Unit tests for + * {@link org.springframework.messaging.converter.AbstractMessageConverter}. * * @author Rossen Stoyanchev */ @@ -109,15 +112,32 @@ public class MessageConverterTests { } @Test - public void toMessageHeadersCopied() { + public void toMessageWithHeaders() { Map map = new HashMap(); map.put("foo", "bar"); MessageHeaders headers = new MessageHeaders(map); Message message = this.converter.toMessage("ABC", headers); - assertEquals("bar", message.getHeaders().get("foo")); assertNotNull(message.getHeaders().getId()); assertNotNull(message.getHeaders().getTimestamp()); + assertEquals(MimeTypeUtils.TEXT_PLAIN, message.getHeaders().get(MessageHeaders.CONTENT_TYPE)); + assertEquals("bar", message.getHeaders().get("foo")); + } + + @Test + public void toMessageWithMutableMessageHeaders() { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + accessor.setHeader("foo", "bar"); + accessor.setNativeHeader("fooNative", "barNative"); + accessor.setLeaveMutable(true); + + MessageHeaders headers = accessor.getMessageHeaders(); + Message message = this.converter.toMessage("ABC", headers); + + assertSame(headers, message.getHeaders()); + assertNull(message.getHeaders().getId()); + assertNull(message.getHeaders().getTimestamp()); + assertEquals(MimeTypeUtils.TEXT_PLAIN, message.getHeaders().get(MessageHeaders.CONTENT_TYPE)); } @Test diff --git a/spring-messaging/src/test/java/org/springframework/messaging/converter/SimpleMessageConverterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/converter/SimpleMessageConverterTests.java new file mode 100644 index 00000000000..9d338ea086c --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/converter/SimpleMessageConverterTests.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.converter; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.support.MessageHeaderAccessor; + +import java.util.Collections; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +/** + * Unit tests for + * {@link org.springframework.messaging.converter.SimpleMessageConverter}. + * + * @author Rossen Stoyanchev + */ +public class SimpleMessageConverterTests { + + private SimpleMessageConverter converter; + + @Before + public void setup() { + this.converter = new SimpleMessageConverter(); + } + + @Test + public void toMessageWithNullPayload() { + assertNull(this.converter.toMessage(null, null)); + } + + @Test + public void toMessageWithPayloadAndHeaders() { + MessageHeaders headers = new MessageHeaders(Collections.singletonMap("foo", "bar")); + Message message = this.converter.toMessage("payload", headers); + + assertEquals("payload", message.getPayload()); + assertEquals("bar", message.getHeaders().get("foo")); + } + + @Test + public void toMessageWithPayloadAndMutableHeaders() { + MessageHeaderAccessor accessor = new MessageHeaderAccessor(); + accessor.setHeader("foo", "bar"); + accessor.setLeaveMutable(true); + MessageHeaders headers = accessor.getMessageHeaders(); + + Message message = this.converter.toMessage("payload", headers); + + assertEquals("payload", message.getPayload()); + assertSame(headers, message.getHeaders()); + assertEquals("bar", message.getHeaders().get("foo")); + } +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/core/GenericMessagingTemplateTests.java b/spring-messaging/src/test/java/org/springframework/messaging/core/GenericMessagingTemplateTests.java index 6d1624b6398..4af3dba3eb7 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/core/GenericMessagingTemplateTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/core/GenericMessagingTemplateTests.java @@ -16,6 +16,7 @@ package org.springframework.messaging.core; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -27,10 +28,13 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessagingException; +import org.springframework.messaging.StubMessageChannel; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.GenericMessage; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import static org.junit.Assert.*; @@ -44,12 +48,17 @@ public class GenericMessagingTemplateTests { private GenericMessagingTemplate template; + private StubMessageChannel messageChannel; + private ThreadPoolTaskExecutor executor; @Before public void setup() { + this.messageChannel = new StubMessageChannel(); this.template = new GenericMessagingTemplate(); + this.template.setDefaultDestination(this.messageChannel); + this.template.setDestinationResolver(new TestDestinationResolver()); this.executor = new ThreadPoolTaskExecutor(); this.executor.afterPropertiesSet(); } @@ -114,4 +123,26 @@ public class GenericMessagingTemplateTests { } } + @Test + public void convertAndSendWithSimpMessageHeaders() { + MessageHeaderAccessor accessor = new MessageHeaderAccessor(); + accessor.setHeader("key", "value"); + accessor.setLeaveMutable(true); + MessageHeaders headers = accessor.getMessageHeaders(); + + this.template.convertAndSend("channel", "data", headers); + List> messages = this.messageChannel.getMessages(); + Message message = messages.get(0); + + assertSame(headers, message.getHeaders()); + assertFalse(accessor.isMutable()); + } + + private class TestDestinationResolver implements DestinationResolver { + + @Override + public MessageChannel resolveDestination(String name) throws DestinationResolutionException { + return messageChannel; + } + } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/core/MessageSendingTemplateTests.java b/spring-messaging/src/test/java/org/springframework/messaging/core/MessageSendingTemplateTests.java index 877b5405547..9e30b4c0e7c 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/core/MessageSendingTemplateTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/core/MessageSendingTemplateTests.java @@ -16,6 +16,7 @@ package org.springframework.messaging.core; +import java.nio.charset.Charset; import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -27,6 +28,8 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.converter.*; import org.springframework.messaging.support.GenericMessage; +import org.springframework.messaging.support.MessageHeaderAccessor; +import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; import static org.junit.Assert.*; @@ -122,6 +125,22 @@ public class MessageSendingTemplateTests { assertEquals("payload", this.template.message.getPayload()); } + @Test + public void convertAndSendPayloadAndMutableHeadersToDestination() { + MessageHeaderAccessor accessor = new MessageHeaderAccessor(); + accessor.setHeader("foo", "bar"); + accessor.setLeaveMutable(true); + MessageHeaders messageHeaders = accessor.getMessageHeaders(); + + this.template.setMessageConverter(new StringMessageConverter()); + this.template.convertAndSend("somewhere", "payload", messageHeaders); + + MessageHeaders actual = this.template.message.getHeaders(); + assertSame(messageHeaders, actual); + assertEquals(new MimeType("text", "plain", Charset.forName("UTF-8")), actual.get(MessageHeaders.CONTENT_TYPE)); + assertEquals("bar", actual.get("foo")); + } + @Test public void convertAndSendPayloadWithPostProcessor() { this.template.setDefaultDestination("home"); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessagingTemplateTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessagingTemplateTests.java index 276f60a5a2c..43f2d1b6558 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessagingTemplateTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessagingTemplateTests.java @@ -16,17 +16,27 @@ package org.springframework.messaging.simp; +import org.apache.activemq.transport.stomp.Stomp; import org.junit.Before; import org.junit.Test; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.StubMessageChannel; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.support.NativeMessageHeaderAccessor; import org.springframework.util.LinkedMultiValueMap; import java.util.*; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; /** * Unit tests for {@link org.springframework.messaging.simp.SimpMessagingTemplate}. @@ -43,7 +53,7 @@ public class SimpMessagingTemplateTests { @Before public void setup() { this.messageChannel = new StubMessageChannel(); - this.messagingTemplate = new SimpMessagingTemplate(messageChannel); + this.messagingTemplate = new SimpMessagingTemplate(this.messageChannel); } @@ -55,10 +65,12 @@ public class SimpMessagingTemplateTests { assertEquals(1, messages.size()); Message message = messages.get(0); - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + SimpMessageHeaderAccessor headerAccessor = + MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class); - assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); - assertEquals("/user/joe/queue/foo", headers.getDestination()); + assertNotNull(headerAccessor); + assertEquals(SimpMessageType.MESSAGE, headerAccessor.getMessageType()); + assertEquals("/user/joe/queue/foo", headerAccessor.getDestination()); } @Test @@ -68,9 +80,11 @@ public class SimpMessagingTemplateTests { assertEquals(1, messages.size()); - Message message = messages.get(0); - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - assertEquals("/user/http:%2F%2Fjoe.openid.example.org%2F/queue/foo", headers.getDestination()); + SimpMessageHeaderAccessor headerAccessor = + MessageHeaderAccessor.getAccessor(messages.get(0), SimpMessageHeaderAccessor.class); + + assertNotNull(headerAccessor); + assertEquals("/user/http:%2F%2Fjoe.openid.example.org%2F/queue/foo", headerAccessor.getDestination()); } @Test @@ -79,11 +93,13 @@ public class SimpMessagingTemplateTests { this.messagingTemplate.convertAndSend("/foo", "data", headers); List> messages = this.messageChannel.getMessages(); - Message message = messages.get(0); - SimpMessageHeaderAccessor resultHeaders = SimpMessageHeaderAccessor.wrap(message); - assertNull(resultHeaders.toMap().get("key")); - assertEquals(Arrays.asList("value"), resultHeaders.getNativeHeader("key")); + SimpMessageHeaderAccessor headerAccessor = + MessageHeaderAccessor.getAccessor(messages.get(0), SimpMessageHeaderAccessor.class); + + assertNotNull(headerAccessor); + assertNull(headerAccessor.toMap().get("key")); + assertEquals(Arrays.asList("value"), headerAccessor.getNativeHeader("key")); } @Test @@ -93,12 +109,79 @@ public class SimpMessagingTemplateTests { headers.put(NativeMessageHeaderAccessor.NATIVE_HEADERS, new LinkedMultiValueMap()); this.messagingTemplate.convertAndSend("/foo", "data", headers); + List> messages = this.messageChannel.getMessages(); + + SimpMessageHeaderAccessor headerAccessor = + MessageHeaderAccessor.getAccessor(messages.get(0), SimpMessageHeaderAccessor.class); + + assertNotNull(headerAccessor); + assertEquals("value", headerAccessor.toMap().get("key")); + assertNull(headerAccessor.getNativeHeader("key")); + } + + @Test + public void convertAndSendWithMutableSimpMessageHeaders() { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(); + accessor.setHeader("key", "value"); + accessor.setNativeHeader("fooNative", "barNative"); + accessor.setLeaveMutable(true); + MessageHeaders headers = accessor.getMessageHeaders(); + + this.messagingTemplate.convertAndSend("/foo", "data", headers); + List> messages = this.messageChannel.getMessages(); Message message = messages.get(0); - SimpMessageHeaderAccessor resultHeaders = SimpMessageHeaderAccessor.wrap(message); - assertEquals("value", resultHeaders.toMap().get("key")); - assertNull(resultHeaders.getNativeHeader("key")); + assertSame(headers, message.getHeaders()); + assertFalse(accessor.isMutable()); + } + + @Test + public void processHeadersToSend() { + Map map = this.messagingTemplate.processHeadersToSend(null); + + assertNotNull(map); + assertTrue("Actual: " + map.getClass().toString(), MessageHeaders.class.isAssignableFrom(map.getClass())); + + SimpMessageHeaderAccessor headerAccessor = + MessageHeaderAccessor.getAccessor((MessageHeaders) map, SimpMessageHeaderAccessor.class); + + assertTrue(headerAccessor.isMutable()); + assertEquals(SimpMessageType.MESSAGE, headerAccessor.getMessageType()); + } + + @Test + public void doSendWithMutableHeaders() { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(); + accessor.setHeader("key", "value"); + accessor.setNativeHeader("fooNative", "barNative"); + accessor.setLeaveMutable(true); + MessageHeaders headers = accessor.getMessageHeaders(); + Message message = MessageBuilder.createMessage("payload", headers); + + this.messagingTemplate.doSend("/topic/foo", message); + + List> messages = this.messageChannel.getMessages(); + Message sentMessage = messages.get(0); + + assertSame(message, sentMessage); + assertFalse(accessor.isMutable()); + } + + @Test + public void doSendWithStompHeaders() { + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); + accessor.setDestination("/user/queue/foo"); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + + this.messagingTemplate.doSend("/queue/foo-user123", message); + + List> messages = this.messageChannel.getMessages(); + Message sentMessage = messages.get(0); + + MessageHeaderAccessor sentAccessor = MessageHeaderAccessor.getAccessor(sentMessage, MessageHeaderAccessor.class); + assertEquals(StompHeaderAccessor.class, sentAccessor.getClass()); + assertEquals("/queue/foo-user123", ((StompHeaderAccessor) sentAccessor).getDestination()); } -} +} \ No newline at end of file diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java index 1e2dd57fceb..bb2e7206e26 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java @@ -17,6 +17,7 @@ package org.springframework.messaging.simp.annotation.support; import java.lang.reflect.Method; +import java.nio.charset.Charset; import java.security.Principal; import javax.security.auth.Subject; @@ -27,22 +28,29 @@ import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.converter.StringMessageConverter; +import org.springframework.messaging.core.MessageSendingOperations; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageSendingOperations; import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.support.MessageBuilder; -import org.springframework.messaging.converter.MessageConverter; +import org.springframework.messaging.support.MessageHeaderAccessor; +import org.springframework.util.MimeType; import static org.junit.Assert.*; import static org.mockito.Matchers.*; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.*; /** @@ -52,7 +60,9 @@ import static org.mockito.Mockito.*; */ public class SendToMethodReturnValueHandlerTests { - private static final String payloadContent = "payload"; + public static final MimeType MIME_TYPE = new MimeType("text", "plain", Charset.forName("UTF-8")); + + private static final String PAYLOAD = "payload"; private SendToMethodReturnValueHandler handler; @@ -63,8 +73,6 @@ public class SendToMethodReturnValueHandlerTests { @Captor ArgumentCaptor> messageCaptor; - @Mock private MessageConverter messageConverter; - private MethodParameter noAnnotationsReturnType; private MethodParameter sendToReturnType; private MethodParameter sendToDefaultDestReturnType; @@ -78,11 +86,8 @@ public class SendToMethodReturnValueHandlerTests { MockitoAnnotations.initMocks(this); - Message message = MessageBuilder.withPayload(payloadContent).build(); - when(this.messageConverter.toMessage(payloadContent, null)).thenReturn(message); - SimpMessagingTemplate messagingTemplate = new SimpMessagingTemplate(this.messageChannel); - messagingTemplate.setMessageConverter(this.messageConverter); + messagingTemplate.setMessageConverter(new StringMessageConverter()); this.handler = new SendToMethodReturnValueHandler(messagingTemplate, true); this.handlerAnnotationNotRequired = new SendToMethodReturnValueHandler(messagingTemplate, false); @@ -118,15 +123,16 @@ public class SendToMethodReturnValueHandlerTests { when(this.messageChannel.send(any(Message.class))).thenReturn(true); Message inputMessage = createInputMessage("sess1", "sub1", "/app", "/dest", null); - this.handler.handleReturnValue(payloadContent, this.noAnnotationsReturnType, inputMessage); + this.handler.handleReturnValue(PAYLOAD, this.noAnnotationsReturnType, inputMessage); verify(this.messageChannel, times(1)).send(this.messageCaptor.capture()); Message message = this.messageCaptor.getAllValues().get(0); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); assertEquals("sess1", headers.getSessionId()); - assertNull(headers.getSubscriptionId()); assertEquals("/topic/dest", headers.getDestination()); + assertEquals(MIME_TYPE, headers.getContentType()); + assertNull("Subscription id should not be copied", headers.getSubscriptionId()); } @Test @@ -136,21 +142,23 @@ public class SendToMethodReturnValueHandlerTests { String sessionId = "sess1"; Message inputMessage = createInputMessage(sessionId, "sub1", null, null, null); - this.handler.handleReturnValue(payloadContent, this.sendToReturnType, inputMessage); + this.handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage); verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); Message message = this.messageCaptor.getAllValues().get(0); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); assertEquals(sessionId, headers.getSessionId()); - assertNull(headers.getSubscriptionId()); assertEquals("/dest1", headers.getDestination()); + assertEquals(MIME_TYPE, headers.getContentType()); + assertNull("Subscription id should not be copied", headers.getSubscriptionId()); message = this.messageCaptor.getAllValues().get(1); headers = SimpMessageHeaderAccessor.wrap(message); assertEquals(sessionId, headers.getSessionId()); - assertNull(headers.getSubscriptionId()); assertEquals("/dest2", headers.getDestination()); + assertEquals(MIME_TYPE, headers.getContentType()); + assertNull("Subscription id should not be copied", headers.getSubscriptionId()); } @Test @@ -160,15 +168,38 @@ public class SendToMethodReturnValueHandlerTests { String sessionId = "sess1"; Message inputMessage = createInputMessage(sessionId, "sub1", "/app", "/dest", null); - this.handler.handleReturnValue(payloadContent, this.sendToDefaultDestReturnType, inputMessage); + this.handler.handleReturnValue(PAYLOAD, this.sendToDefaultDestReturnType, inputMessage); verify(this.messageChannel, times(1)).send(this.messageCaptor.capture()); Message message = this.messageCaptor.getAllValues().get(0); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); assertEquals(sessionId, headers.getSessionId()); - assertNull(headers.getSubscriptionId()); assertEquals("/topic/dest", headers.getDestination()); + assertEquals(MIME_TYPE, headers.getContentType()); + assertNull("Subscription id should not be copied", headers.getSubscriptionId()); + } + + @Test + public void testHeadersToSend() throws Exception { + + Message inputMessage = createInputMessage("sess1", "sub1", "/app", "/dest", null); + + SimpMessageSendingOperations messagingTemplate = Mockito.mock(SimpMessageSendingOperations.class); + SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(messagingTemplate, false); + + handler.handleReturnValue(PAYLOAD, this.noAnnotationsReturnType, inputMessage); + + ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHeaders.class); + verify(messagingTemplate).convertAndSend(eq("/topic/dest"), eq(PAYLOAD), captor.capture()); + + SimpMessageHeaderAccessor headerAccessor = + MessageHeaderAccessor.getAccessor(captor.getValue(), SimpMessageHeaderAccessor.class); + + assertNotNull(headerAccessor); + assertTrue(headerAccessor.isMutable()); + assertEquals("sess1", headerAccessor.getSessionId()); + assertNull("Subscription id should not be copied", headerAccessor.getSubscriptionId()); } @Test @@ -179,21 +210,23 @@ public class SendToMethodReturnValueHandlerTests { String sessionId = "sess1"; TestUser user = new TestUser(); Message inputMessage = createInputMessage(sessionId, "sub1", null, null, user); - this.handler.handleReturnValue(payloadContent, this.sendToUserReturnType, inputMessage); + this.handler.handleReturnValue(PAYLOAD, this.sendToUserReturnType, inputMessage); verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); Message message = this.messageCaptor.getAllValues().get(0); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); assertEquals(sessionId, headers.getSessionId()); - assertNull(headers.getSubscriptionId()); + assertEquals(MIME_TYPE, headers.getContentType()); assertEquals("/user/" + user.getName() + "/dest1", headers.getDestination()); + assertNull("Subscription id should not be copied", headers.getSubscriptionId()); message = this.messageCaptor.getAllValues().get(1); headers = SimpMessageHeaderAccessor.wrap(message); assertEquals(sessionId, headers.getSessionId()); - assertNull(headers.getSubscriptionId()); assertEquals("/user/" + user.getName() + "/dest2", headers.getDestination()); + assertEquals(MIME_TYPE, headers.getContentType()); + assertNull("Subscription id should not be copied", headers.getSubscriptionId()); } @Test @@ -204,7 +237,7 @@ public class SendToMethodReturnValueHandlerTests { String sessionId = "sess1"; TestUser user = new UniqueUser(); Message inputMessage = createInputMessage(sessionId, "sub1", null, null, user); - this.handler.handleReturnValue(payloadContent, this.sendToUserReturnType, inputMessage); + this.handler.handleReturnValue(PAYLOAD, this.sendToUserReturnType, inputMessage); verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); @@ -223,31 +256,56 @@ public class SendToMethodReturnValueHandlerTests { String sessionId = "sess1"; TestUser user = new TestUser(); Message inputMessage = createInputMessage(sessionId, "sub1", "/app", "/dest", user); - this.handler.handleReturnValue(payloadContent, this.sendToUserDefaultDestReturnType, inputMessage); + this.handler.handleReturnValue(PAYLOAD, this.sendToUserDefaultDestReturnType, inputMessage); verify(this.messageChannel, times(1)).send(this.messageCaptor.capture()); Message message = this.messageCaptor.getAllValues().get(0); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); assertEquals(sessionId, headers.getSessionId()); - assertNull(headers.getSubscriptionId()); assertEquals("/user/" + user.getName() + "/queue/dest", headers.getDestination()); + assertEquals(MIME_TYPE, headers.getContentType()); + assertNull("Subscription id should not be copied", headers.getSubscriptionId()); + } + + @Test + public void testHeadersToSendToUser() throws Exception { + + TestUser user = new TestUser(); + Message inputMessage = createInputMessage("sess1", "sub1", "/app", "/dest", user); + + SimpMessageSendingOperations messagingTemplate = Mockito.mock(SimpMessageSendingOperations.class); + SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(messagingTemplate, false); + + handler.handleReturnValue(PAYLOAD, this.sendToUserDefaultDestReturnType, inputMessage); + + ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHeaders.class); + verify(messagingTemplate).convertAndSendToUser(eq("joe"), eq("/queue/dest"), eq(PAYLOAD), captor.capture()); + + SimpMessageHeaderAccessor headerAccessor = + MessageHeaderAccessor.getAccessor(captor.getValue(), SimpMessageHeaderAccessor.class); + + assertNotNull(headerAccessor); + assertTrue(headerAccessor.isMutable()); + assertEquals("sess1", headerAccessor.getSessionId()); + assertNull("Subscription id should not be copied", headerAccessor.getSubscriptionId()); } private Message createInputMessage(String sessId, String subsId, String destinationPrefix, String destination, Principal principal) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); - headers.setSessionId(sessId); - headers.setSubscriptionId(subsId); + + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(); + headerAccessor.setSessionId(sessId); + headerAccessor.setSubscriptionId(subsId); if (destination != null && destinationPrefix != null) { - headers.setDestination(destinationPrefix + destination); - headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination); + headerAccessor.setDestination(destinationPrefix + destination); + headerAccessor.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination); } if (principal != null) { - headers.setUser(principal); + headerAccessor.setUser(principal); } - return MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); + return MessageBuilder.createMessage(new byte[0], headerAccessor.getMessageHeaders()); } private static class TestUser implements Principal { @@ -270,27 +328,27 @@ public class SendToMethodReturnValueHandlerTests { } public String handleNoAnnotations() { - return payloadContent; + return PAYLOAD; } @SendTo public String handleAndSendToDefaultDestination() { - return payloadContent; + return PAYLOAD; } @SendTo({"/dest1", "/dest2"}) public String handleAndSendTo() { - return payloadContent; + return PAYLOAD; } @SendToUser public String handleAndSendToUserDefaultDestination() { - return payloadContent; + return PAYLOAD; } @SendToUser({"/dest1", "/dest2"}) public String handleAndSendToUser() { - return payloadContent; + return PAYLOAD; } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java index 7c6e038f7e4..a239620708c 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java @@ -17,6 +17,7 @@ package org.springframework.messaging.simp.annotation.support; import java.lang.reflect.Method; +import java.nio.charset.Charset; import java.security.Principal; import org.junit.Before; @@ -24,17 +25,22 @@ import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.converter.StringMessageConverter; +import org.springframework.messaging.core.MessageSendingOperations; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.annotation.SubscribeMapping; import org.springframework.messaging.support.MessageBuilder; -import org.springframework.messaging.converter.MessageConverter; +import org.springframework.messaging.support.MessageHeaderAccessor; +import org.springframework.util.MimeType; import static org.junit.Assert.*; import static org.mockito.Matchers.*; @@ -47,7 +53,9 @@ import static org.mockito.Mockito.*; */ public class SubscriptionMethodReturnValueHandlerTests { - private static final String payloadContent = "payload"; + public static final MimeType MIME_TYPE = new MimeType("text", "plain", Charset.forName("UTF-8")); + + private static final String PAYLOAD = "payload"; private SubscriptionMethodReturnValueHandler handler; @@ -56,8 +64,6 @@ public class SubscriptionMethodReturnValueHandlerTests { @Captor ArgumentCaptor> messageCaptor; - @Mock private MessageConverter messageConverter; - private MethodParameter subscribeEventReturnType; private MethodParameter subscribeEventSendToReturnType; @@ -71,11 +77,8 @@ public class SubscriptionMethodReturnValueHandlerTests { MockitoAnnotations.initMocks(this); - Message message = MessageBuilder.withPayload(payloadContent).build(); - when(this.messageConverter.toMessage(payloadContent, null)).thenReturn(message); - SimpMessagingTemplate messagingTemplate = new SimpMessagingTemplate(this.messageChannel); - messagingTemplate.setMessageConverter(this.messageConverter); + messagingTemplate.setMessageConverter(new StringMessageConverter()); this.handler = new SubscriptionMethodReturnValueHandler(messagingTemplate); @@ -98,7 +101,7 @@ public class SubscriptionMethodReturnValueHandlerTests { } @Test - public void subscribeEventMethod() throws Exception { + public void testMessageSentToChannel() throws Exception { when(this.messageChannel.send(any(Message.class))).thenReturn(true); @@ -107,17 +110,46 @@ public class SubscriptionMethodReturnValueHandlerTests { String destination = "/dest"; Message inputMessage = createInputMessage(sessionId, subscriptionId, destination, null); - this.handler.handleReturnValue(payloadContent, this.subscribeEventReturnType, inputMessage); + this.handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage); verify(this.messageChannel).send(this.messageCaptor.capture()); assertNotNull(this.messageCaptor.getValue()); Message message = this.messageCaptor.getValue(); - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); + + assertNull("SimpMessageHeaderAccessor should have disabled id", headerAccessor.getId()); + assertNull("SimpMessageHeaderAccessor should have disabled timestamp", headerAccessor.getTimestamp()); + assertEquals(sessionId, headerAccessor.getSessionId()); + assertEquals(subscriptionId, headerAccessor.getSubscriptionId()); + assertEquals(destination, headerAccessor.getDestination()); + assertEquals(MIME_TYPE, headerAccessor.getContentType()); + } + + @SuppressWarnings("unchecked") + @Test + public void testHeadersPassedToMessagingTemplate() throws Exception { + + String sessionId = "sess1"; + String subscriptionId = "subs1"; + String destination = "/dest"; + Message inputMessage = createInputMessage(sessionId, subscriptionId, destination, null); + + MessageSendingOperations messagingTemplate = Mockito.mock(MessageSendingOperations.class); + SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate); + + handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage); + + ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHeaders.class); + verify(messagingTemplate).convertAndSend(eq("/dest"), eq(PAYLOAD), captor.capture()); + + SimpMessageHeaderAccessor headerAccessor = + MessageHeaderAccessor.getAccessor(captor.getValue(), SimpMessageHeaderAccessor.class); - assertEquals("sessionId should always be copied", sessionId, headers.getSessionId()); - assertEquals(subscriptionId, headers.getSubscriptionId()); - assertEquals(destination, headers.getDestination()); + assertNotNull(headerAccessor); + assertTrue(headerAccessor.isMutable()); + assertEquals(sessionId, headerAccessor.getSessionId()); + assertEquals(subscriptionId, headerAccessor.getSubscriptionId()); } @@ -131,19 +163,22 @@ public class SubscriptionMethodReturnValueHandlerTests { } + @SuppressWarnings("unused") @SubscribeMapping("/data") // not needed for the tests but here for completeness private String getData() { - return payloadContent; + return PAYLOAD; } + @SuppressWarnings("unused") @SubscribeMapping("/data") // not needed for the tests but here for completeness @SendTo("/sendToDest") private String getDataAndSendTo() { - return payloadContent; + return PAYLOAD; } + @SuppressWarnings("unused") @MessageMapping("/handle") // not needed for the tests but here for completeness public String handle() { - return payloadContent; + return PAYLOAD; } }