diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java b/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java index dd42f246f84..aae499bf949 100644 --- a/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java +++ b/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java @@ -23,6 +23,7 @@ package org.springframework.messaging; * @author Mark Fisher * @since 4.0 */ +@SuppressWarnings("rawtypes") public interface MessageChannel { /** diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageFactory.java b/spring-context/src/main/java/org/springframework/messaging/MessageFactory.java index 075cd2f4256..ea0d7a483b5 100644 --- a/spring-context/src/main/java/org/springframework/messaging/MessageFactory.java +++ b/spring-context/src/main/java/org/springframework/messaging/MessageFactory.java @@ -22,9 +22,8 @@ import java.util.Map; /** * A factory for creating messages, allowing for control of the concrete type of the message. * - * - * * @author Andy Wilkinson + * @since 4.0 */ public interface MessageFactory> { diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java b/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java index 52a171e2116..8893717d492 100644 --- a/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java +++ b/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java @@ -24,6 +24,7 @@ package org.springframework.messaging; * @author Iwein Fuld * @since 4.0 */ +@SuppressWarnings("rawtypes") public interface MessageHandler { /** diff --git a/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java b/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java index 6b8961a983c..7e87eadb190 100644 --- a/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java +++ b/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java @@ -25,7 +25,9 @@ package org.springframework.messaging; * @author Mark Fisher * @since 4.0 */ -public interface SubscribableChannel> extends MessageChannel { +@SuppressWarnings("rawtypes") +public interface SubscribableChannel> + extends MessageChannel { /** * Register a {@link MessageHandler} as a subscriber to this channel. diff --git a/spring-context/src/main/java/org/springframework/messaging/GenericMessage.java b/spring-context/src/main/java/org/springframework/messaging/support/GenericMessage.java similarity index 94% rename from spring-context/src/main/java/org/springframework/messaging/GenericMessage.java rename to spring-context/src/main/java/org/springframework/messaging/support/GenericMessage.java index 548baf68063..7c0b6ada2ad 100644 --- a/spring-context/src/main/java/org/springframework/messaging/GenericMessage.java +++ b/spring-context/src/main/java/org/springframework/messaging/support/GenericMessage.java @@ -14,12 +14,14 @@ * limitations under the License. */ -package org.springframework.messaging; +package org.springframework.messaging.support; import java.io.Serializable; import java.util.HashMap; import java.util.Map; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; diff --git a/spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java b/spring-context/src/main/java/org/springframework/messaging/support/GenericMessageFactory.java similarity index 83% rename from spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java rename to spring-context/src/main/java/org/springframework/messaging/support/GenericMessageFactory.java index 076cc5590f2..30f80823fa6 100644 --- a/spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java +++ b/spring-context/src/main/java/org/springframework/messaging/support/GenericMessageFactory.java @@ -14,20 +14,24 @@ * limitations under the License. */ -package org.springframework.messaging; +package org.springframework.messaging.support; import java.util.Map; +import org.springframework.messaging.MessageFactory; + /** * A {@link MessageFactory} that creates {@link GenericMessage GenericMessages}. * * @author Andy Wilkinson + * @since 4.0 */ public class GenericMessageFactory implements MessageFactory> { + @Override - public

GenericMessage createMessage(P payload, Map headers) { + public

GenericMessage

createMessage(P payload, Map headers) { return new GenericMessage

(payload, headers); } diff --git a/spring-context/src/main/java/org/springframework/messaging/support/MessageBuilder.java b/spring-context/src/main/java/org/springframework/messaging/support/MessageBuilder.java new file mode 100644 index 00000000000..22cf40cf283 --- /dev/null +++ b/spring-context/src/main/java/org/springframework/messaging/support/MessageBuilder.java @@ -0,0 +1,245 @@ +/* + * Copyright 2002-2010 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageFactory; +import org.springframework.messaging.MessageHeaders; +import org.springframework.util.Assert; +import org.springframework.util.PatternMatchUtils; +import org.springframework.util.StringUtils; + +/** + * @author Arjen Poutsma + * @author Mark Fisher + * @author Oleg Zhurakousky + * @author Dave Syer + * @since 4.0 + */ +public final class MessageBuilder { + + private final T payload; + + private final Map headers = new HashMap(); + + private final Message originalMessage; + + @SuppressWarnings("rawtypes") + private static volatile MessageFactory messageFactory = null; + + + /** + * A constructor with payload and an optional message to copy headers from. + * This is a private constructor to be invoked from the static factory methods only. + * + * @param payload the message payload, never {@code null} + * @param originalMessage a message to copy from or re-use if no changes are made, can + * be {@code null} + */ + private MessageBuilder(T payload, Message originalMessage) { + Assert.notNull(payload, "payload is required"); + this.payload = payload; + this.originalMessage = originalMessage; + if (originalMessage != null) { + this.headers.putAll(originalMessage.getHeaders()); + } + } + + /** + * Private constructor to be invoked from the static factory methods only. + * + * @param payload the message payload, never {@code null} + * @param originalMessage a message to copy from or re-use if no changes are made, can + * be {@code null} + */ + private MessageBuilder(T payload, Map headers) { + Assert.notNull(payload, "payload is required"); + Assert.notNull(headers, "headers is required"); + this.payload = payload; + this.headers.putAll(headers); + this.originalMessage = null; + } + + /** + * Create a builder for a new {@link Message} instance pre-populated with all of the + * headers copied from the provided message. The payload of the provided Message will + * also be used as the payload for the new message. + * + * @param message the Message from which the payload and all headers will be copied + */ + public static MessageBuilder fromMessage(Message message) { + Assert.notNull(message, "message must not be null"); + MessageBuilder builder = new MessageBuilder(message.getPayload(), message); + return builder; + } + + /** + * Create a builder for a new {@link Message} instance with the provided payload and + * headers. + * + * @param payload the payload for the new message + * @param headers the headers to use + */ + public static MessageBuilder fromPayloadAndHeaders(T payload, Map headers) { + MessageBuilder builder = new MessageBuilder(payload, headers); + return builder; + } + + /** + * Create a builder for a new {@link Message} instance with the provided payload. + * + * @param payload the payload for the new message + */ + public static MessageBuilder withPayload(T payload) { + MessageBuilder builder = new MessageBuilder(payload, (Message) null); + return builder; + } + + /** + * Set the value for the given header name. If the provided value is null + * the header will be removed. + */ + public MessageBuilder setHeader(String headerName, Object headerValue) { + Assert.isTrue(!this.isReadOnly(headerName), "The '" + headerName + "' header is read-only."); + if (StringUtils.hasLength(headerName)) { + putOrRemove(headerName, headerValue); + } + return this; + } + + private boolean isReadOnly(String headerName) { + return MessageHeaders.ID.equals(headerName) || MessageHeaders.TIMESTAMP.equals(headerName); + } + + private void putOrRemove(String headerName, Object headerValue) { + if (headerValue == null) { + this.headers.remove(headerName); + } + else { + this.headers.put(headerName, headerValue); + } + } + + /** + * Set the value for the given header name only if the header name is not already + * associated with a value. + */ + public MessageBuilder setHeaderIfAbsent(String headerName, Object headerValue) { + if (this.headers.get(headerName) == null) { + putOrRemove(headerName, headerValue); + } + return this; + } + + /** + * Removes all headers provided via array of 'headerPatterns'. As the name suggests + * the array may contain simple matching patterns for header names. Supported pattern + * styles are: "xxx*", "*xxx", "*xxx*" and "xxx*yyy". + */ + public MessageBuilder removeHeaders(String... headerPatterns) { + List toRemove = new ArrayList(); + for (String pattern : headerPatterns) { + if (StringUtils.hasLength(pattern)){ + if (pattern.contains("*")){ + for (String headerName : this.headers.keySet()) { + if (PatternMatchUtils.simpleMatch(pattern, headerName)){ + toRemove.add(headerName); + } + } + } + else { + toRemove.add(pattern); + } + } + } + for (String headerName : toRemove) { + this.headers.remove(headerName); + putOrRemove(headerName, null); + } + return this; + } + /** + * Remove the value for the given header name. + */ + public MessageBuilder removeHeader(String headerName) { + if (StringUtils.hasLength(headerName) && !isReadOnly(headerName)) { + this.headers.remove(headerName); + } + return this; + } + + /** + * Copy the name-value pairs from the provided Map. This operation will overwrite any + * existing values. Use { {@link #copyHeadersIfAbsent(Map)} to avoid overwriting + * values. Note that the 'id' and 'timestamp' header values will never be overwritten. + */ + public MessageBuilder copyHeaders(Map headersToCopy) { + Set keys = headersToCopy.keySet(); + for (String key : keys) { + if (!this.isReadOnly(key)) { + putOrRemove(key, headersToCopy.get(key)); + } + } + return this; + } + + /** + * Copy the name-value pairs from the provided Map. This operation will not + * overwrite any existing values. + */ + public MessageBuilder copyHeadersIfAbsent(Map headersToCopy) { + Set keys = headersToCopy.keySet(); + for (String key : keys) { + if (!this.isReadOnly(key) && (this.headers.get(key) == null)) { + putOrRemove(key, headersToCopy.get(key)); + } + } + return this; + } + + @SuppressWarnings("unchecked") + public Message build() { + + if (this.originalMessage != null + && this.headers.equals(this.originalMessage.getHeaders()) + && this.payload.equals(this.originalMessage.getPayload())) { + + return this.originalMessage; + } + +// if (this.payload instanceof Throwable) { +// return (Message) new ErrorMessage((Throwable) this.payload, this.headers); +// } + + this.headers.remove(MessageHeaders.ID); + this.headers.remove(MessageHeaders.TIMESTAMP); + + if (messageFactory == null) { + return new GenericMessage(this.payload, this.headers); + } + else { + return messageFactory.createMessage(payload, headers); + } + } + +} diff --git a/spring-context/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java b/spring-context/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java new file mode 100644 index 00000000000..bfd3bfc4196 --- /dev/null +++ b/spring-context/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java @@ -0,0 +1,169 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support; + +import java.util.Date; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +import org.junit.Test; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; + +import static org.junit.Assert.*; + + +/** + * @author Mark Fisher + */ +public class MessageBuilderTests { + + @Test + public void testSimpleMessageCreation() { + Message message = MessageBuilder.withPayload("foo").build(); + assertEquals("foo", message.getPayload()); + } + + @Test + public void testHeaderValues() { + Message message = MessageBuilder.withPayload("test") + .setHeader("foo", "bar") + .setHeader("count", new Integer(123)) + .build(); + assertEquals("bar", message.getHeaders().get("foo", String.class)); + assertEquals(new Integer(123), message.getHeaders().get("count", Integer.class)); + } + + @Test + public void testCopiedHeaderValues() { + Message message1 = MessageBuilder.withPayload("test1") + .setHeader("foo", "1") + .setHeader("bar", "2") + .build(); + Message message2 = MessageBuilder.withPayload("test2") + .copyHeaders(message1.getHeaders()) + .setHeader("foo", "42") + .setHeaderIfAbsent("bar", "99") + .build(); + assertEquals("test1", message1.getPayload()); + assertEquals("test2", message2.getPayload()); + assertEquals("1", message1.getHeaders().get("foo")); + assertEquals("42", message2.getHeaders().get("foo")); + assertEquals("2", message1.getHeaders().get("bar")); + assertEquals("2", message2.getHeaders().get("bar")); + } + + @Test(expected = IllegalArgumentException.class) + public void testIdHeaderValueReadOnly() { + UUID id = UUID.randomUUID(); + MessageBuilder.withPayload("test").setHeader(MessageHeaders.ID, id); + } + + @Test(expected = IllegalArgumentException.class) + public void testTimestampValueReadOnly() { + Long timestamp = 12345L; + MessageBuilder.withPayload("test").setHeader(MessageHeaders.TIMESTAMP, timestamp).build(); + } + + @Test + public void copyHeadersIfAbsent() { + Message message1 = MessageBuilder.withPayload("test1") + .setHeader("foo", "bar").build(); + Message message2 = MessageBuilder.withPayload("test2") + .setHeader("foo", 123) + .copyHeadersIfAbsent(message1.getHeaders()) + .build(); + assertEquals("test2", message2.getPayload()); + assertEquals(123, message2.getHeaders().get("foo")); + } + + @Test + public void createFromMessage() { + Message message1 = MessageBuilder.withPayload("test") + .setHeader("foo", "bar").build(); + Message message2 = MessageBuilder.fromMessage(message1).build(); + assertEquals("test", message2.getPayload()); + assertEquals("bar", message2.getHeaders().get("foo")); + } + + @Test + public void createIdRegenerated() { + Message message1 = MessageBuilder.withPayload("test") + .setHeader("foo", "bar").build(); + Message message2 = MessageBuilder.fromMessage(message1).setHeader("another", 1).build(); + assertEquals("bar", message2.getHeaders().get("foo")); + assertNotSame(message1.getHeaders().getId(), message2.getHeaders().getId()); + } + + @Test + public void testRemove() { + Message message1 = MessageBuilder.withPayload(1) + .setHeader("foo", "bar").build(); + Message message2 = MessageBuilder.fromMessage(message1) + .removeHeader("foo") + .build(); + assertFalse(message2.getHeaders().containsKey("foo")); + } + + @Test + public void testSettingToNullRemoves() { + Message message1 = MessageBuilder.withPayload(1) + .setHeader("foo", "bar").build(); + Message message2 = MessageBuilder.fromMessage(message1) + .setHeader("foo", null) + .build(); + assertFalse(message2.getHeaders().containsKey("foo")); + } + + @Test + public void testNotModifiedSameMessage() throws Exception { + Message original = MessageBuilder.withPayload("foo").build(); + Message result = MessageBuilder.fromMessage(original).build(); + assertEquals(original, result); + } + + @Test + public void testContainsHeaderNotModifiedSameMessage() throws Exception { + Message original = MessageBuilder.withPayload("foo").setHeader("bar", 42).build(); + Message result = MessageBuilder.fromMessage(original).build(); + assertEquals(original, result); + } + + @Test + public void testSameHeaderValueAddedNotModifiedSameMessage() throws Exception { + Message original = MessageBuilder.withPayload("foo").setHeader("bar", 42).build(); + Message result = MessageBuilder.fromMessage(original).setHeader("bar", 42).build(); + assertEquals(original, result); + } + + @Test + public void testCopySameHeaderValuesNotModifiedSameMessage() throws Exception { + Date current = new Date(); + Map originalHeaders = new HashMap(); + originalHeaders.put("b", "xyz"); + originalHeaders.put("c", current); + Message original = MessageBuilder.withPayload("foo").setHeader("a", 123).copyHeaders(originalHeaders).build(); + Map newHeaders = new HashMap(); + newHeaders.put("a", 123); + newHeaders.put("b", "xyz"); + newHeaders.put("c", current); + Message result = MessageBuilder.fromMessage(original).copyHeaders(newHeaders).build(); + assertEquals(original, result); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java index b0da4182ec9..873d9ef766f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java @@ -54,6 +54,7 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler>> subscriptionsBySession = new ConcurrentHashMap>>(); @@ -62,48 +58,17 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { super(publishChannel, clientChannel); this.reactor = reactor; this.payloadConverter = new CompositeMessageConverter(null); - this.messageFactory = new GenericMessageFactory(); - } - - public void setMessageFactory(MessageFactory messageFactory) { - this.messageFactory = messageFactory; } public void setMessageConverters(List converters) { this.payloadConverter = new CompositeMessageConverter(converters); } - @SuppressWarnings("unchecked") - @Override - public void handlePublish(Message message) { - - if (logger.isDebugEnabled()) { - logger.debug("Message received: " + message); - } - - try { - // Convert to byte[] payload before the fan-out - PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders()); - byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType()); - message = messageFactory.createMessage(payload, message.getHeaders()); - - this.reactor.notify(getPublishKey(inHeaders.getDestination()), Event.wrap(message)); - } - catch (Exception ex) { - logger.error("Failed to publish " + message, ex); - } - } - - private String getPublishKey(String destination) { - return "destination:" + destination; - } - @Override protected Collection getSupportedMessageTypes() { return Arrays.asList(MessageType.MESSAGE, MessageType.SUBSCRIBE, MessageType.UNSUBSCRIBE); } - @Override public void handleSubscribe(Message message) { @@ -112,33 +77,13 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { } PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); - final String subscriptionId = headers.getSubscriptionId(); - - Selector selector = new ObjectSelector(getPublishKey(headers.getDestination())); - Registration registration = this.reactor.on(selector, - new Consumer>>() { - @SuppressWarnings("unchecked") - @Override - public void accept(Event> event) { - Message message = event.getData(); - PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders()); - PubSubHeaders outHeaders = PubSubHeaders.create(); - outHeaders.setDestinations(inHeaders.getDestinations()); - if (inHeaders.getContentType() != null) { - outHeaders.setContentType(inHeaders.getContentType()); - } - outHeaders.setSubscriptionId(subscriptionId); - Object payload = message.getPayload(); - - Message outMessage = messageFactory.createMessage(payload, outHeaders.toMessageHeaders()); - getClientChannel().send(outMessage); - } - }); - - addSubscription(headers.getSessionId(), registration); - } + String subscriptionId = headers.getSubscriptionId(); + BroadcastingConsumer consumer = new BroadcastingConsumer(subscriptionId); - private void addSubscription(String sessionId, Registration registration) { + String key = getPublishKey(headers.getDestination()); + Registration registration = this.reactor.on(new ObjectSelector(key), consumer); + + String sessionId = headers.getSessionId(); List> list = this.subscriptionsBySession.get(sessionId); if (list == null) { list = new ArrayList>(); @@ -147,6 +92,30 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { list.add(registration); } + private String getPublishKey(String destination) { + return "destination:" + destination; + } + + @Override + public void handlePublish(Message message) { + + if (logger.isDebugEnabled()) { + logger.debug("Message received: " + message); + } + + try { + // Convert to byte[] payload before the fan-out + PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + byte[] payload = payloadConverter.convertToPayload(message.getPayload(), headers.getContentType()); + message = MessageBuilder.fromPayloadAndHeaders(payload, message.getHeaders()).build(); + + this.reactor.notify(getPublishKey(headers.getDestination()), Event.wrap(message)); + } + catch (Exception ex) { + logger.error("Failed to publish " + message, ex); + } + } + @Override public void handleDisconnect(Message message) { PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); @@ -158,6 +127,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { removeSubscriptions(sessionId); } */ + private void removeSubscriptions(String sessionId) { List> registrations = this.subscriptionsBySession.remove(sessionId); if (logger.isTraceEnabled()) { @@ -168,4 +138,30 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { } } + + private final class BroadcastingConsumer implements Consumer>> { + + private final String subscriptionId; + + + private BroadcastingConsumer(String subscriptionId) { + this.subscriptionId = subscriptionId; + } + + @SuppressWarnings("unchecked") + @Override + public void accept(Event> event) { + + Message sentMessage = event.getData(); + + PubSubHeaders clientHeaders = PubSubHeaders.fromMessageHeaders(sentMessage.getHeaders()); + clientHeaders.setSubscriptionId(this.subscriptionId); + + Message clientMessage = MessageBuilder.fromPayloadAndHeaders(sentMessage.getPayload(), + clientHeaders.toMessageHeaders()).build(); + + getClientChannel().send(clientMessage); + } + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java index d26366b89e6..7587fe79d24 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java @@ -31,10 +31,8 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotationUtils; -import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.MessageFactory; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.annotation.MessageMapping; import org.springframework.stereotype.Controller; @@ -71,9 +69,6 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler private ReturnValueHandlerComposite returnValueHandlers = new ReturnValueHandlerComposite(); - private MessageFactory messageFactory = new GenericMessageFactory(); - - public AnnotationPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) { super(publishChannel, clientChannel); @@ -83,10 +78,6 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler this.messageConverters = converters; } - public void setMessageFactory(MessageFactory messageFactory) { - this.messageFactory = messageFactory; - } - @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { this.applicationContext = applicationContext; @@ -99,17 +90,13 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler @Override public void afterPropertiesSet() { - initHandlerMethods(); - MessageChannelArgumentResolver messageChannelArgumentResolver = new MessageChannelArgumentResolver(getPublishChannel()); - messageChannelArgumentResolver.setMessageFactory(messageFactory); - this.argumentResolvers.addResolver(messageChannelArgumentResolver); + initHandlerMethods(); + this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(getPublishChannel())); this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverters)); - MessageReturnValueHandler messageReturnValueHandler = new MessageReturnValueHandler(getClientChannel()); - messageReturnValueHandler.setMessageFactory(messageFactory); - this.returnValueHandlers.addHandler(messageReturnValueHandler); + this.returnValueHandlers.addHandler(new MessageReturnValueHandler(getClientChannel())); } protected void initHandlerMethods() { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java index 3e7ba1c58bd..6b3a762e5a5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java @@ -17,10 +17,9 @@ package org.springframework.web.messaging.service.method; import org.springframework.core.MethodParameter; -import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.MessageFactory; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.web.messaging.PubSubHeaders; @@ -33,16 +32,10 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { private final MessageChannel publishChannel; - private MessageFactory messageFactory; public MessageChannelArgumentResolver(MessageChannel publishChannel) { Assert.notNull(publishChannel, "publishChannel is required"); this.publishChannel = publishChannel; - this.messageFactory = new GenericMessageFactory(); - } - - public void setMessageFactory(MessageFactory messageFactory) { - this.messageFactory = messageFactory; } @Override @@ -67,7 +60,9 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { public boolean send(Message message, long timeout) { PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); headers.setSessionId(sessionId); - publishChannel.send(messageFactory.createMessage(message.getPayload(), headers.toMessageHeaders())); + MessageBuilder messageToSend = MessageBuilder.fromPayloadAndHeaders( + message.getPayload(), headers.toMessageHeaders()); + publishChannel.send(messageToSend.build()); return true; } }; diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java index 4a1e69c04ba..4d53b339044 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java @@ -17,10 +17,11 @@ package org.springframework.web.messaging.service.method; import org.springframework.core.MethodParameter; -import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageFactory; +import org.springframework.messaging.support.GenericMessageFactory; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.web.messaging.PubSubHeaders; @@ -73,17 +74,24 @@ public class MessageReturnValueHandler implements ReturnValueHandler { return; } - PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders()); - String sessionId = inHeaders.getSessionId(); - String subscriptionId = inHeaders.getSubscriptionId(); - Assert.notNull(subscriptionId, "No subscription id: " + message); - - PubSubHeaders outHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders()); - outHeaders.setSessionId(sessionId); - outHeaders.setSubscriptionId(subscriptionId); - returnMessage = messageFactory.createMessage(returnMessage.getPayload(), outHeaders.toMessageHeaders()); + returnMessage = updateReturnMessage(returnMessage, message); this.clientChannel.send(returnMessage); } + protected Message updateReturnMessage(Message returnMessage, Message message) { + + PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + String sessionId = headers.getSessionId(); + String subscriptionId = headers.getSubscriptionId(); + + Assert.notNull(subscriptionId, "No subscription id: " + message); + + PubSubHeaders returnHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders()); + returnHeaders.setSessionId(sessionId); + returnHeaders.setSubscriptionId(subscriptionId); + + return MessageBuilder.fromPayloadAndHeaders(returnMessage.getPayload(), returnHeaders.toMessageHeaders()).build(); + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java index 52086b68c65..124a38eaaff 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java @@ -19,12 +19,11 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.Charset; import java.util.List; -import java.util.Map; import java.util.Map.Entry; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageFactory; import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -51,7 +50,7 @@ public class StompMessageConverter { /** * @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String. */ - public > M toMessage(Object stompContent, String sessionId, MessageFactory messageFactory) { + public Message toMessage(Object stompContent, String sessionId) { byte[] byteContent = null; if (stompContent instanceof String) { @@ -102,7 +101,7 @@ public class StompMessageConverter { byte[] payload = new byte[totalLength - payloadIndex]; System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex); - return createMessage(command, stompHeaders.toMessageHeaders(), payload, messageFactory); + return MessageBuilder.fromPayloadAndHeaders(payload, stompHeaders.toMessageHeaders()).build(); } private int findIndexOfPayload(byte[] bytes) { @@ -132,10 +131,6 @@ public class StompMessageConverter { return index; } - protected > M createMessage(StompCommand command, Map headers, byte[] payload, MessageFactory messageFactory) { - return messageFactory.createMessage(payload, headers); - } - public byte[] fromMessage(Message message) { ByteArrayOutputStream out = new ByteArrayOutputStream(); MessageHeaders messageHeaders = message.getHeaders(); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java index 5783971b72f..7f7c3ae6e3e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java @@ -23,11 +23,10 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.springframework.http.MediaType; -import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.MessageFactory; import org.springframework.messaging.SubscribableChannel; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.PubSubHeaders; @@ -57,8 +56,6 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler private MessageConverter payloadConverter; - private MessageFactory messageFactory = new GenericMessageFactory(); - private final TcpClient tcpClient; private final Map> connections = @@ -82,10 +79,6 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler this.payloadConverter = new CompositeMessageConverter(converters); } - public void setMessageFactory(MessageFactory messageFactory) { - this.messageFactory = messageFactory; - } - @Override protected Collection getSupportedMessageTypes() { return null; @@ -117,7 +110,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler // TODO: why are we getting empty frames? return; } - Message message = stompMessageConverter.toMessage(stompFrame, sessionId, messageFactory); + Message message = stompMessageConverter.toMessage(stompFrame, sessionId); getClientChannel().send(message); } }); @@ -134,19 +127,18 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler } - @SuppressWarnings("unchecked") private void forwardMessage(Message message, StompCommand command) { - StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); - String sessionId = stompHeaders.getSessionId(); + StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders()); + String sessionId = headers.getSessionId(); byte[] bytesToWrite; try { - stompHeaders.setStompCommandIfNotSet(StompCommand.SEND); + headers.setStompCommandIfNotSet(StompCommand.SEND); - MediaType contentType = stompHeaders.getContentType(); + MediaType contentType = headers.getContentType(); byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType); - Message byteMessage = messageFactory.createMessage(payload, stompHeaders.toMessageHeaders()); + Message byteMessage = MessageBuilder.fromPayloadAndHeaders(payload, headers.toMessageHeaders()).build(); bytesToWrite = this.stompMessageConverter.fromMessage(byteMessage); } catch (Throwable ex) { @@ -158,7 +150,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler Assert.notNull(connection, "TCP connection to message broker not found, sessionId=" + sessionId); try { if (logger.isTraceEnabled()) { - logger.trace("Forwarding STOMP " + stompHeaders.getStompCommand() + " message"); + logger.trace("Forwarding STOMP " + headers.getStompCommand() + " message"); } connection.out().accept(new String(bytesToWrite, Charset.forName("UTF-8"))); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java index e09a905be21..3a03c3a68e2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java @@ -25,12 +25,11 @@ import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.MediaType; -import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.MessageFactory; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.SubscribableChannel; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.converter.CompositeMessageConverter; @@ -50,6 +49,11 @@ import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; */ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { + /** + * + */ + private static final byte[] EMPTY_PAYLOAD = new byte[0]; + private static Log logger = LogFactory.getLog(StompWebSocketHandler.class); private final MessageChannel publishChannel; @@ -60,8 +64,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { private MessageConverter payloadConverter = new CompositeMessageConverter(null); - private MessageFactory messageFactory = new GenericMessageFactory(); - @SuppressWarnings("unchecked") public StompWebSocketHandler(MessageChannel publishChannel, SubscribableChannel clientChannel) { @@ -78,10 +80,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { this.payloadConverter = new CompositeMessageConverter(converters); } - public void setMessageFactory(MessageFactory messageFactory) { - this.messageFactory = messageFactory; - } - public StompMessageConverter getStompMessageConverter() { return this.stompMessageConverter; } @@ -101,7 +99,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) { try { String payload = textMessage.getPayload(); - Message message = this.stompMessageConverter.toMessage(payload, session.getId(), messageFactory); + Message message = this.stompMessageConverter.toMessage(payload, session.getId()); // TODO: validate size limits // http://stomp.github.io/stomp-specification-1.2.html#Size_Limits @@ -144,18 +142,17 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { } } - @SuppressWarnings("unchecked") protected void handleConnect(final WebSocketSession session, Message message) throws IOException { - StompHeaders connectStompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); - StompHeaders connectedStompHeaders = StompHeaders.create(StompCommand.CONNECTED); + StompHeaders connectHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); + StompHeaders connectedHeaders = StompHeaders.create(StompCommand.CONNECTED); - Set acceptVersions = connectStompHeaders.getAcceptVersion(); + Set acceptVersions = connectHeaders.getAcceptVersion(); if (acceptVersions.contains("1.2")) { - connectedStompHeaders.setAcceptVersion("1.2"); + connectedHeaders.setAcceptVersion("1.2"); } else if (acceptVersions.contains("1.1")) { - connectedStompHeaders.setAcceptVersion("1.1"); + connectedHeaders.setAcceptVersion("1.1"); } else if (acceptVersions.isEmpty()) { // 1.0 @@ -163,11 +160,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { else { throw new StompConversionException("Unsupported version '" + acceptVersions + "'"); } - connectedStompHeaders.setHeartbeat(0,0); // TODO + connectedHeaders.setHeartbeat(0,0); // TODO // TODO: security - Message connectedMessage = messageFactory.createMessage(new byte[0], connectedStompHeaders.toMessageHeaders()); + Message connectedMessage = MessageBuilder.fromPayloadAndHeaders(EMPTY_PAYLOAD, + connectedHeaders.toMessageHeaders()).build(); byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } @@ -187,14 +185,14 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { protected void handleDisconnect(Message stompMessage) { } - @SuppressWarnings("unchecked") protected void sendErrorMessage(WebSocketSession session, Throwable error) { - StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR); - stompHeaders.setMessage(error.getMessage()); + StompHeaders headers = StompHeaders.create(StompCommand.ERROR); + headers.setMessage(error.getMessage()); - Message errorMessage = messageFactory.createMessage(new byte[0], stompHeaders.toMessageHeaders()); - byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage); + Message message = MessageBuilder.fromPayloadAndHeaders(EMPTY_PAYLOAD, + headers.toMessageHeaders()).build(); + byte[] bytes = this.stompMessageConverter.fromMessage(message); try { session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); @@ -214,19 +212,18 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { private final class ClientMessageConsumer implements MessageHandler> { - @SuppressWarnings("unchecked") @Override public void handleMessage(Message message) { - StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); - stompHeaders.setStompCommandIfNotSet(StompCommand.MESSAGE); + StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders()); + headers.setStompCommandIfNotSet(StompCommand.MESSAGE); - if (StompCommand.CONNECTED.equals(stompHeaders.getStompCommand())) { + if (StompCommand.CONNECTED.equals(headers.getStompCommand())) { // Ignore for now since we already sent it return; } - String sessionId = stompHeaders.getSessionId(); + String sessionId = headers.getSessionId(); if (sessionId == null) { logger.error("No \"sessionId\" header in message: " + message); } @@ -237,7 +234,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { byte[] payload; try { - MediaType contentType = stompHeaders.getContentType(); + MediaType contentType = headers.getContentType(); payload = payloadConverter.convertToPayload(message.getPayload(), contentType); } catch (Throwable t) { @@ -246,8 +243,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { } try { - Map messageHeaders = stompHeaders.toMessageHeaders(); - Message byteMessage = messageFactory.createMessage(payload, messageHeaders); + Message byteMessage = MessageBuilder.fromPayloadAndHeaders(payload, + headers.toMessageHeaders()).build(); byte[] bytes = getStompMessageConverter().fromMessage(byteMessage); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } @@ -255,7 +252,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { sendErrorMessage(session, t); } finally { - if (StompCommand.ERROR.equals(stompHeaders.getStompCommand())) { + if (StompCommand.ERROR.equals(headers.getStompCommand())) { try { session.close(CloseStatus.PROTOCOL_ERROR); } diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java index 366e9f758f2..e2eaed097f5 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java @@ -19,9 +19,7 @@ import java.util.Collections; import org.junit.Before; import org.junit.Test; -import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageFactory; import org.springframework.messaging.MessageHeaders; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.stomp.StompCommand; @@ -37,22 +35,19 @@ public class StompMessageConverterTests { private StompMessageConverter converter; - private MessageFactory messageFactory = new GenericMessageFactory(); - @Before public void setup() { this.converter = new StompMessageConverter(); } - @SuppressWarnings("unchecked") @Test public void connectFrame() throws Exception { String accept = "accept-version:1.1\n"; String host = "host:github.org\n"; String frame = "\n\n\nCONNECT\n" + accept + host + "\n"; - Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123", messageFactory); + Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); assertEquals(0, message.getPayload().length); @@ -76,14 +71,13 @@ public class StompMessageConverterTests { assertTrue(convertedBack.contains(host)); } - @SuppressWarnings("unchecked") @Test public void connectWithEscapes() throws Exception { String accept = "accept-version:1.1\n"; String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; String frame = "CONNECT\n" + accept + host + "\n"; - Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123", messageFactory); + Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); assertEquals(0, message.getPayload().length); @@ -99,14 +93,13 @@ public class StompMessageConverterTests { assertTrue(convertedBack.contains(host)); } - @SuppressWarnings("unchecked") @Test public void connectCR12() throws Exception { String accept = "accept-version:1.2\n"; String host = "host:github.org\n"; String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123", messageFactory); + Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); assertEquals(0, message.getPayload().length); @@ -122,14 +115,13 @@ public class StompMessageConverterTests { assertTrue(convertedBack.contains(host)); } - @SuppressWarnings("unchecked") @Test public void connectWithEscapesAndCR12() throws Exception { String accept = "accept-version:1.1\n"; String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123", messageFactory); + Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); assertEquals(0, message.getPayload().length);