Browse Source

Add MessageBuilder

pull/286/merge
Rossen Stoyanchev 13 years ago
parent
commit
c5b1f02c3a
  1. 1
      spring-context/src/main/java/org/springframework/messaging/MessageChannel.java
  2. 3
      spring-context/src/main/java/org/springframework/messaging/MessageFactory.java
  3. 1
      spring-context/src/main/java/org/springframework/messaging/MessageHandler.java
  4. 4
      spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java
  5. 4
      spring-context/src/main/java/org/springframework/messaging/support/GenericMessage.java
  6. 8
      spring-context/src/main/java/org/springframework/messaging/support/GenericMessageFactory.java
  7. 245
      spring-context/src/main/java/org/springframework/messaging/support/MessageBuilder.java
  8. 169
      spring-context/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java
  9. 1
      spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java
  10. 120
      spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java
  11. 19
      spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java
  12. 13
      spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java
  13. 28
      spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java
  14. 11
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java
  15. 24
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java
  16. 59
      spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java
  17. 16
      spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java

1
spring-context/src/main/java/org/springframework/messaging/MessageChannel.java

@ -23,6 +23,7 @@ package org.springframework.messaging; @@ -23,6 +23,7 @@ package org.springframework.messaging;
* @author Mark Fisher
* @since 4.0
*/
@SuppressWarnings("rawtypes")
public interface MessageChannel<M extends Message> {
/**

3
spring-context/src/main/java/org/springframework/messaging/MessageFactory.java

@ -22,9 +22,8 @@ import java.util.Map; @@ -22,9 +22,8 @@ import java.util.Map;
/**
* A factory for creating messages, allowing for control of the concrete type of the message.
*
*
*
* @author Andy Wilkinson
* @since 4.0
*/
public interface MessageFactory<M extends Message<?>> {

1
spring-context/src/main/java/org/springframework/messaging/MessageHandler.java

@ -24,6 +24,7 @@ package org.springframework.messaging; @@ -24,6 +24,7 @@ package org.springframework.messaging;
* @author Iwein Fuld
* @since 4.0
*/
@SuppressWarnings("rawtypes")
public interface MessageHandler<M extends Message> {
/**

4
spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java

@ -25,7 +25,9 @@ package org.springframework.messaging; @@ -25,7 +25,9 @@ package org.springframework.messaging;
* @author Mark Fisher
* @since 4.0
*/
public interface SubscribableChannel<M extends Message, H extends MessageHandler<M>> extends MessageChannel<M> {
@SuppressWarnings("rawtypes")
public interface SubscribableChannel<M extends Message, H extends MessageHandler<M>>
extends MessageChannel<M> {
/**
* Register a {@link MessageHandler} as a subscriber to this channel.

4
spring-context/src/main/java/org/springframework/messaging/GenericMessage.java → spring-context/src/main/java/org/springframework/messaging/support/GenericMessage.java

@ -14,12 +14,14 @@ @@ -14,12 +14,14 @@
* limitations under the License.
*/
package org.springframework.messaging;
package org.springframework.messaging.support;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;

8
spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java → spring-context/src/main/java/org/springframework/messaging/support/GenericMessageFactory.java

@ -14,20 +14,24 @@ @@ -14,20 +14,24 @@
* limitations under the License.
*/
package org.springframework.messaging;
package org.springframework.messaging.support;
import java.util.Map;
import org.springframework.messaging.MessageFactory;
/**
* A {@link MessageFactory} that creates {@link GenericMessage GenericMessages}.
*
* @author Andy Wilkinson
* @since 4.0
*/
public class GenericMessageFactory implements MessageFactory<GenericMessage<?>> {
@Override
public <P> GenericMessage<?> createMessage(P payload, Map<String, Object> headers) {
public <P> GenericMessage<P> createMessage(P payload, Map<String, Object> headers) {
return new GenericMessage<P>(payload, headers);
}

245
spring-context/src/main/java/org/springframework/messaging/support/MessageBuilder.java

@ -0,0 +1,245 @@ @@ -0,0 +1,245 @@
/*
* Copyright 2002-2010 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.util.PatternMatchUtils;
import org.springframework.util.StringUtils;
/**
* @author Arjen Poutsma
* @author Mark Fisher
* @author Oleg Zhurakousky
* @author Dave Syer
* @since 4.0
*/
public final class MessageBuilder<T> {
private final T payload;
private final Map<String, Object> headers = new HashMap<String, Object>();
private final Message<T> originalMessage;
@SuppressWarnings("rawtypes")
private static volatile MessageFactory messageFactory = null;
/**
* A constructor with payload and an optional message to copy headers from.
* This is a private constructor to be invoked from the static factory methods only.
*
* @param payload the message payload, never {@code null}
* @param originalMessage a message to copy from or re-use if no changes are made, can
* be {@code null}
*/
private MessageBuilder(T payload, Message<T> originalMessage) {
Assert.notNull(payload, "payload is required");
this.payload = payload;
this.originalMessage = originalMessage;
if (originalMessage != null) {
this.headers.putAll(originalMessage.getHeaders());
}
}
/**
* Private constructor to be invoked from the static factory methods only.
*
* @param payload the message payload, never {@code null}
* @param originalMessage a message to copy from or re-use if no changes are made, can
* be {@code null}
*/
private MessageBuilder(T payload, Map<String, Object> headers) {
Assert.notNull(payload, "payload is required");
Assert.notNull(headers, "headers is required");
this.payload = payload;
this.headers.putAll(headers);
this.originalMessage = null;
}
/**
* Create a builder for a new {@link Message} instance pre-populated with all of the
* headers copied from the provided message. The payload of the provided Message will
* also be used as the payload for the new message.
*
* @param message the Message from which the payload and all headers will be copied
*/
public static <T> MessageBuilder<T> fromMessage(Message<T> message) {
Assert.notNull(message, "message must not be null");
MessageBuilder<T> builder = new MessageBuilder<T>(message.getPayload(), message);
return builder;
}
/**
* Create a builder for a new {@link Message} instance with the provided payload and
* headers.
*
* @param payload the payload for the new message
* @param headers the headers to use
*/
public static <T> MessageBuilder<T> fromPayloadAndHeaders(T payload, Map<String, Object> headers) {
MessageBuilder<T> builder = new MessageBuilder<T>(payload, headers);
return builder;
}
/**
* Create a builder for a new {@link Message} instance with the provided payload.
*
* @param payload the payload for the new message
*/
public static <T> MessageBuilder<T> withPayload(T payload) {
MessageBuilder<T> builder = new MessageBuilder<T>(payload, (Message<T>) null);
return builder;
}
/**
* Set the value for the given header name. If the provided value is <code>null</code>
* the header will be removed.
*/
public MessageBuilder<T> setHeader(String headerName, Object headerValue) {
Assert.isTrue(!this.isReadOnly(headerName), "The '" + headerName + "' header is read-only.");
if (StringUtils.hasLength(headerName)) {
putOrRemove(headerName, headerValue);
}
return this;
}
private boolean isReadOnly(String headerName) {
return MessageHeaders.ID.equals(headerName) || MessageHeaders.TIMESTAMP.equals(headerName);
}
private void putOrRemove(String headerName, Object headerValue) {
if (headerValue == null) {
this.headers.remove(headerName);
}
else {
this.headers.put(headerName, headerValue);
}
}
/**
* Set the value for the given header name only if the header name is not already
* associated with a value.
*/
public MessageBuilder<T> setHeaderIfAbsent(String headerName, Object headerValue) {
if (this.headers.get(headerName) == null) {
putOrRemove(headerName, headerValue);
}
return this;
}
/**
* Removes all headers provided via array of 'headerPatterns'. As the name suggests
* the array may contain simple matching patterns for header names. Supported pattern
* styles are: "xxx*", "*xxx", "*xxx*" and "xxx*yyy".
*/
public MessageBuilder<T> removeHeaders(String... headerPatterns) {
List<String> toRemove = new ArrayList<String>();
for (String pattern : headerPatterns) {
if (StringUtils.hasLength(pattern)){
if (pattern.contains("*")){
for (String headerName : this.headers.keySet()) {
if (PatternMatchUtils.simpleMatch(pattern, headerName)){
toRemove.add(headerName);
}
}
}
else {
toRemove.add(pattern);
}
}
}
for (String headerName : toRemove) {
this.headers.remove(headerName);
putOrRemove(headerName, null);
}
return this;
}
/**
* Remove the value for the given header name.
*/
public MessageBuilder<T> removeHeader(String headerName) {
if (StringUtils.hasLength(headerName) && !isReadOnly(headerName)) {
this.headers.remove(headerName);
}
return this;
}
/**
* Copy the name-value pairs from the provided Map. This operation will overwrite any
* existing values. Use { {@link #copyHeadersIfAbsent(Map)} to avoid overwriting
* values. Note that the 'id' and 'timestamp' header values will never be overwritten.
*/
public MessageBuilder<T> copyHeaders(Map<String, ?> headersToCopy) {
Set<String> keys = headersToCopy.keySet();
for (String key : keys) {
if (!this.isReadOnly(key)) {
putOrRemove(key, headersToCopy.get(key));
}
}
return this;
}
/**
* Copy the name-value pairs from the provided Map. This operation will <em>not</em>
* overwrite any existing values.
*/
public MessageBuilder<T> copyHeadersIfAbsent(Map<String, ?> headersToCopy) {
Set<String> keys = headersToCopy.keySet();
for (String key : keys) {
if (!this.isReadOnly(key) && (this.headers.get(key) == null)) {
putOrRemove(key, headersToCopy.get(key));
}
}
return this;
}
@SuppressWarnings("unchecked")
public Message<T> build() {
if (this.originalMessage != null
&& this.headers.equals(this.originalMessage.getHeaders())
&& this.payload.equals(this.originalMessage.getPayload())) {
return this.originalMessage;
}
// if (this.payload instanceof Throwable) {
// return (Message<T>) new ErrorMessage((Throwable) this.payload, this.headers);
// }
this.headers.remove(MessageHeaders.ID);
this.headers.remove(MessageHeaders.TIMESTAMP);
if (messageFactory == null) {
return new GenericMessage<T>(this.payload, this.headers);
}
else {
return messageFactory.createMessage(payload, headers);
}
}
}

169
spring-context/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java

@ -0,0 +1,169 @@ @@ -0,0 +1,169 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import static org.junit.Assert.*;
/**
* @author Mark Fisher
*/
public class MessageBuilderTests {
@Test
public void testSimpleMessageCreation() {
Message<String> message = MessageBuilder.withPayload("foo").build();
assertEquals("foo", message.getPayload());
}
@Test
public void testHeaderValues() {
Message<String> message = MessageBuilder.withPayload("test")
.setHeader("foo", "bar")
.setHeader("count", new Integer(123))
.build();
assertEquals("bar", message.getHeaders().get("foo", String.class));
assertEquals(new Integer(123), message.getHeaders().get("count", Integer.class));
}
@Test
public void testCopiedHeaderValues() {
Message<String> message1 = MessageBuilder.withPayload("test1")
.setHeader("foo", "1")
.setHeader("bar", "2")
.build();
Message<String> message2 = MessageBuilder.withPayload("test2")
.copyHeaders(message1.getHeaders())
.setHeader("foo", "42")
.setHeaderIfAbsent("bar", "99")
.build();
assertEquals("test1", message1.getPayload());
assertEquals("test2", message2.getPayload());
assertEquals("1", message1.getHeaders().get("foo"));
assertEquals("42", message2.getHeaders().get("foo"));
assertEquals("2", message1.getHeaders().get("bar"));
assertEquals("2", message2.getHeaders().get("bar"));
}
@Test(expected = IllegalArgumentException.class)
public void testIdHeaderValueReadOnly() {
UUID id = UUID.randomUUID();
MessageBuilder.withPayload("test").setHeader(MessageHeaders.ID, id);
}
@Test(expected = IllegalArgumentException.class)
public void testTimestampValueReadOnly() {
Long timestamp = 12345L;
MessageBuilder.withPayload("test").setHeader(MessageHeaders.TIMESTAMP, timestamp).build();
}
@Test
public void copyHeadersIfAbsent() {
Message<String> message1 = MessageBuilder.withPayload("test1")
.setHeader("foo", "bar").build();
Message<String> message2 = MessageBuilder.withPayload("test2")
.setHeader("foo", 123)
.copyHeadersIfAbsent(message1.getHeaders())
.build();
assertEquals("test2", message2.getPayload());
assertEquals(123, message2.getHeaders().get("foo"));
}
@Test
public void createFromMessage() {
Message<String> message1 = MessageBuilder.withPayload("test")
.setHeader("foo", "bar").build();
Message<String> message2 = MessageBuilder.fromMessage(message1).build();
assertEquals("test", message2.getPayload());
assertEquals("bar", message2.getHeaders().get("foo"));
}
@Test
public void createIdRegenerated() {
Message<String> message1 = MessageBuilder.withPayload("test")
.setHeader("foo", "bar").build();
Message<String> message2 = MessageBuilder.fromMessage(message1).setHeader("another", 1).build();
assertEquals("bar", message2.getHeaders().get("foo"));
assertNotSame(message1.getHeaders().getId(), message2.getHeaders().getId());
}
@Test
public void testRemove() {
Message<Integer> message1 = MessageBuilder.withPayload(1)
.setHeader("foo", "bar").build();
Message<Integer> message2 = MessageBuilder.fromMessage(message1)
.removeHeader("foo")
.build();
assertFalse(message2.getHeaders().containsKey("foo"));
}
@Test
public void testSettingToNullRemoves() {
Message<Integer> message1 = MessageBuilder.withPayload(1)
.setHeader("foo", "bar").build();
Message<Integer> message2 = MessageBuilder.fromMessage(message1)
.setHeader("foo", null)
.build();
assertFalse(message2.getHeaders().containsKey("foo"));
}
@Test
public void testNotModifiedSameMessage() throws Exception {
Message<?> original = MessageBuilder.withPayload("foo").build();
Message<?> result = MessageBuilder.fromMessage(original).build();
assertEquals(original, result);
}
@Test
public void testContainsHeaderNotModifiedSameMessage() throws Exception {
Message<?> original = MessageBuilder.withPayload("foo").setHeader("bar", 42).build();
Message<?> result = MessageBuilder.fromMessage(original).build();
assertEquals(original, result);
}
@Test
public void testSameHeaderValueAddedNotModifiedSameMessage() throws Exception {
Message<?> original = MessageBuilder.withPayload("foo").setHeader("bar", 42).build();
Message<?> result = MessageBuilder.fromMessage(original).setHeader("bar", 42).build();
assertEquals(original, result);
}
@Test
public void testCopySameHeaderValuesNotModifiedSameMessage() throws Exception {
Date current = new Date();
Map<String, Object> originalHeaders = new HashMap<String, Object>();
originalHeaders.put("b", "xyz");
originalHeaders.put("c", current);
Message<?> original = MessageBuilder.withPayload("foo").setHeader("a", 123).copyHeaders(originalHeaders).build();
Map<String, Object> newHeaders = new HashMap<String, Object>();
newHeaders.put("a", 123);
newHeaders.put("b", "xyz");
newHeaders.put("c", current);
Message<?> result = MessageBuilder.fromMessage(original).copyHeaders(newHeaders).build();
assertEquals(original, result);
}
}

1
spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java

@ -54,6 +54,7 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler<Mes @@ -54,6 +54,7 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler<Mes
private final PathMatcher pathMatcher = new AntPathMatcher();
/**
* @param publishChannel a channel for publishing messages from within the
* application

120
spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java

@ -23,11 +23,10 @@ import java.util.List; @@ -23,11 +23,10 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
@ -38,7 +37,6 @@ import reactor.fn.Consumer; @@ -38,7 +37,6 @@ import reactor.fn.Consumer;
import reactor.fn.Event;
import reactor.fn.registry.Registration;
import reactor.fn.selector.ObjectSelector;
import reactor.fn.selector.Selector;
/**
@ -51,8 +49,6 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { @@ -51,8 +49,6 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
private MessageConverter payloadConverter;
private MessageFactory messageFactory;
private Map<String, List<Registration<?>>> subscriptionsBySession = new ConcurrentHashMap<String, List<Registration<?>>>();
@ -62,48 +58,17 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { @@ -62,48 +58,17 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
super(publishChannel, clientChannel);
this.reactor = reactor;
this.payloadConverter = new CompositeMessageConverter(null);
this.messageFactory = new GenericMessageFactory();
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
public void setMessageConverters(List<MessageConverter> converters) {
this.payloadConverter = new CompositeMessageConverter(converters);
}
@SuppressWarnings("unchecked")
@Override
public void handlePublish(Message<?> message) {
if (logger.isDebugEnabled()) {
logger.debug("Message received: " + message);
}
try {
// Convert to byte[] payload before the fan-out
PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType());
message = messageFactory.createMessage(payload, message.getHeaders());
this.reactor.notify(getPublishKey(inHeaders.getDestination()), Event.wrap(message));
}
catch (Exception ex) {
logger.error("Failed to publish " + message, ex);
}
}
private String getPublishKey(String destination) {
return "destination:" + destination;
}
@Override
protected Collection<MessageType> getSupportedMessageTypes() {
return Arrays.asList(MessageType.MESSAGE, MessageType.SUBSCRIBE, MessageType.UNSUBSCRIBE);
}
@Override
public void handleSubscribe(Message<?> message) {
@ -112,33 +77,13 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { @@ -112,33 +77,13 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
}
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
final String subscriptionId = headers.getSubscriptionId();
Selector selector = new ObjectSelector<String>(getPublishKey(headers.getDestination()));
Registration<?> registration = this.reactor.on(selector,
new Consumer<Event<Message<?>>>() {
@SuppressWarnings("unchecked")
@Override
public void accept(Event<Message<?>> event) {
Message<?> message = event.getData();
PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaders outHeaders = PubSubHeaders.create();
outHeaders.setDestinations(inHeaders.getDestinations());
if (inHeaders.getContentType() != null) {
outHeaders.setContentType(inHeaders.getContentType());
}
outHeaders.setSubscriptionId(subscriptionId);
Object payload = message.getPayload();
Message outMessage = messageFactory.createMessage(payload, outHeaders.toMessageHeaders());
getClientChannel().send(outMessage);
}
});
addSubscription(headers.getSessionId(), registration);
}
String subscriptionId = headers.getSubscriptionId();
BroadcastingConsumer consumer = new BroadcastingConsumer(subscriptionId);
private void addSubscription(String sessionId, Registration<?> registration) {
String key = getPublishKey(headers.getDestination());
Registration<?> registration = this.reactor.on(new ObjectSelector<String>(key), consumer);
String sessionId = headers.getSessionId();
List<Registration<?>> list = this.subscriptionsBySession.get(sessionId);
if (list == null) {
list = new ArrayList<Registration<?>>();
@ -147,6 +92,30 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { @@ -147,6 +92,30 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
list.add(registration);
}
private String getPublishKey(String destination) {
return "destination:" + destination;
}
@Override
public void handlePublish(Message<?> message) {
if (logger.isDebugEnabled()) {
logger.debug("Message received: " + message);
}
try {
// Convert to byte[] payload before the fan-out
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), headers.getContentType());
message = MessageBuilder.fromPayloadAndHeaders(payload, message.getHeaders()).build();
this.reactor.notify(getPublishKey(headers.getDestination()), Event.wrap(message));
}
catch (Exception ex) {
logger.error("Failed to publish " + message, ex);
}
}
@Override
public void handleDisconnect(Message<?> message) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
@ -158,6 +127,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { @@ -158,6 +127,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
removeSubscriptions(sessionId);
}
*/
private void removeSubscriptions(String sessionId) {
List<Registration<?>> registrations = this.subscriptionsBySession.remove(sessionId);
if (logger.isTraceEnabled()) {
@ -168,4 +138,30 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { @@ -168,4 +138,30 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
}
}
private final class BroadcastingConsumer implements Consumer<Event<Message<?>>> {
private final String subscriptionId;
private BroadcastingConsumer(String subscriptionId) {
this.subscriptionId = subscriptionId;
}
@SuppressWarnings("unchecked")
@Override
public void accept(Event<Message<?>> event) {
Message<?> sentMessage = event.getData();
PubSubHeaders clientHeaders = PubSubHeaders.fromMessageHeaders(sentMessage.getHeaders());
clientHeaders.setSubscriptionId(this.subscriptionId);
Message<?> clientMessage = MessageBuilder.fromPayloadAndHeaders(sentMessage.getPayload(),
clientHeaders.toMessageHeaders()).build();
getClientChannel().send(clientMessage);
}
}
}

19
spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java

@ -31,10 +31,8 @@ import org.springframework.context.ApplicationContext; @@ -31,10 +31,8 @@ import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.annotation.MessageMapping;
import org.springframework.stereotype.Controller;
@ -71,9 +69,6 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler @@ -71,9 +69,6 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler
private ReturnValueHandlerComposite returnValueHandlers = new ReturnValueHandlerComposite();
private MessageFactory messageFactory = new GenericMessageFactory();
public AnnotationPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) {
super(publishChannel, clientChannel);
@ -83,10 +78,6 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler @@ -83,10 +78,6 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler
this.messageConverters = converters;
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = applicationContext;
@ -99,17 +90,13 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler @@ -99,17 +90,13 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler
@Override
public void afterPropertiesSet() {
initHandlerMethods();
MessageChannelArgumentResolver messageChannelArgumentResolver = new MessageChannelArgumentResolver(getPublishChannel());
messageChannelArgumentResolver.setMessageFactory(messageFactory);
this.argumentResolvers.addResolver(messageChannelArgumentResolver);
initHandlerMethods();
this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(getPublishChannel()));
this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverters));
MessageReturnValueHandler messageReturnValueHandler = new MessageReturnValueHandler(getClientChannel());
messageReturnValueHandler.setMessageFactory(messageFactory);
this.returnValueHandlers.addHandler(messageReturnValueHandler);
this.returnValueHandlers.addHandler(new MessageReturnValueHandler(getClientChannel()));
}
protected void initHandlerMethods() {

13
spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java

@ -17,10 +17,9 @@ @@ -17,10 +17,9 @@
package org.springframework.web.messaging.service.method;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.web.messaging.PubSubHeaders;
@ -33,16 +32,10 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { @@ -33,16 +32,10 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
private final MessageChannel publishChannel;
private MessageFactory messageFactory;
public MessageChannelArgumentResolver(MessageChannel publishChannel) {
Assert.notNull(publishChannel, "publishChannel is required");
this.publishChannel = publishChannel;
this.messageFactory = new GenericMessageFactory();
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
@Override
@ -67,7 +60,9 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { @@ -67,7 +60,9 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
public boolean send(Message<?> message, long timeout) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
headers.setSessionId(sessionId);
publishChannel.send(messageFactory.createMessage(message.getPayload(), headers.toMessageHeaders()));
MessageBuilder<?> messageToSend = MessageBuilder.fromPayloadAndHeaders(
message.getPayload(), headers.toMessageHeaders());
publishChannel.send(messageToSend.build());
return true;
}
};

28
spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java

@ -17,10 +17,11 @@ @@ -17,10 +17,11 @@
package org.springframework.web.messaging.service.method;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.support.GenericMessageFactory;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.web.messaging.PubSubHeaders;
@ -73,17 +74,24 @@ public class MessageReturnValueHandler implements ReturnValueHandler { @@ -73,17 +74,24 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
return;
}
PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String sessionId = inHeaders.getSessionId();
String subscriptionId = inHeaders.getSubscriptionId();
Assert.notNull(subscriptionId, "No subscription id: " + message);
PubSubHeaders outHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders());
outHeaders.setSessionId(sessionId);
outHeaders.setSubscriptionId(subscriptionId);
returnMessage = messageFactory.createMessage(returnMessage.getPayload(), outHeaders.toMessageHeaders());
returnMessage = updateReturnMessage(returnMessage, message);
this.clientChannel.send(returnMessage);
}
protected Message<?> updateReturnMessage(Message<?> returnMessage, Message<?> message) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String sessionId = headers.getSessionId();
String subscriptionId = headers.getSubscriptionId();
Assert.notNull(subscriptionId, "No subscription id: " + message);
PubSubHeaders returnHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders());
returnHeaders.setSessionId(sessionId);
returnHeaders.setSubscriptionId(subscriptionId);
return MessageBuilder.fromPayloadAndHeaders(returnMessage.getPayload(), returnHeaders.toMessageHeaders()).build();
}
}

11
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java

@ -19,12 +19,11 @@ import java.io.ByteArrayOutputStream; @@ -19,12 +19,11 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
@ -51,7 +50,7 @@ public class StompMessageConverter { @@ -51,7 +50,7 @@ public class StompMessageConverter {
/**
* @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String.
*/
public <M extends Message<?>> M toMessage(Object stompContent, String sessionId, MessageFactory<M> messageFactory) {
public Message<byte[]> toMessage(Object stompContent, String sessionId) {
byte[] byteContent = null;
if (stompContent instanceof String) {
@ -102,7 +101,7 @@ public class StompMessageConverter { @@ -102,7 +101,7 @@ public class StompMessageConverter {
byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
return createMessage(command, stompHeaders.toMessageHeaders(), payload, messageFactory);
return MessageBuilder.fromPayloadAndHeaders(payload, stompHeaders.toMessageHeaders()).build();
}
private int findIndexOfPayload(byte[] bytes) {
@ -132,10 +131,6 @@ public class StompMessageConverter { @@ -132,10 +131,6 @@ public class StompMessageConverter {
return index;
}
protected <M extends Message<?>> M createMessage(StompCommand command, Map<String, Object> headers, byte[] payload, MessageFactory<M> messageFactory) {
return messageFactory.createMessage(payload, headers);
}
public byte[] fromMessage(Message<byte[]> message) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
MessageHeaders messageHeaders = message.getHeaders();

24
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java

@ -23,11 +23,10 @@ import java.util.Map; @@ -23,11 +23,10 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubHeaders;
@ -57,8 +56,6 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler @@ -57,8 +56,6 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
private MessageConverter payloadConverter;
private MessageFactory messageFactory = new GenericMessageFactory();
private final TcpClient<String, String> tcpClient;
private final Map<String, TcpConnection<String, String>> connections =
@ -82,10 +79,6 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler @@ -82,10 +79,6 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
this.payloadConverter = new CompositeMessageConverter(converters);
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
@Override
protected Collection<MessageType> getSupportedMessageTypes() {
return null;
@ -117,7 +110,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler @@ -117,7 +110,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
// TODO: why are we getting empty frames?
return;
}
Message<byte[]> message = stompMessageConverter.toMessage(stompFrame, sessionId, messageFactory);
Message<byte[]> message = stompMessageConverter.toMessage(stompFrame, sessionId);
getClientChannel().send(message);
}
});
@ -134,19 +127,18 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler @@ -134,19 +127,18 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
}
@SuppressWarnings("unchecked")
private void forwardMessage(Message<?> message, StompCommand command) {
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
String sessionId = stompHeaders.getSessionId();
StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders());
String sessionId = headers.getSessionId();
byte[] bytesToWrite;
try {
stompHeaders.setStompCommandIfNotSet(StompCommand.SEND);
headers.setStompCommandIfNotSet(StompCommand.SEND);
MediaType contentType = stompHeaders.getContentType();
MediaType contentType = headers.getContentType();
byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType);
Message<byte[]> byteMessage = messageFactory.createMessage(payload, stompHeaders.toMessageHeaders());
Message<byte[]> byteMessage = MessageBuilder.fromPayloadAndHeaders(payload, headers.toMessageHeaders()).build();
bytesToWrite = this.stompMessageConverter.fromMessage(byteMessage);
}
catch (Throwable ex) {
@ -158,7 +150,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler @@ -158,7 +150,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
Assert.notNull(connection, "TCP connection to message broker not found, sessionId=" + sessionId);
try {
if (logger.isTraceEnabled()) {
logger.trace("Forwarding STOMP " + stompHeaders.getStompCommand() + " message");
logger.trace("Forwarding STOMP " + headers.getStompCommand() + " message");
}
connection.out().accept(new String(bytesToWrite, Charset.forName("UTF-8")));
}

59
spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java

@ -25,12 +25,11 @@ import java.util.concurrent.ConcurrentHashMap; @@ -25,12 +25,11 @@ import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
@ -50,6 +49,11 @@ import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; @@ -50,6 +49,11 @@ import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
*/
public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
/**
*
*/
private static final byte[] EMPTY_PAYLOAD = new byte[0];
private static Log logger = LogFactory.getLog(StompWebSocketHandler.class);
private final MessageChannel publishChannel;
@ -60,8 +64,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @@ -60,8 +64,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
private MessageConverter payloadConverter = new CompositeMessageConverter(null);
private MessageFactory messageFactory = new GenericMessageFactory();
@SuppressWarnings("unchecked")
public StompWebSocketHandler(MessageChannel publishChannel, SubscribableChannel clientChannel) {
@ -78,10 +80,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @@ -78,10 +80,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
this.payloadConverter = new CompositeMessageConverter(converters);
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
public StompMessageConverter getStompMessageConverter() {
return this.stompMessageConverter;
}
@ -101,7 +99,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @@ -101,7 +99,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) {
try {
String payload = textMessage.getPayload();
Message<byte[]> message = this.stompMessageConverter.toMessage(payload, session.getId(), messageFactory);
Message<byte[]> message = this.stompMessageConverter.toMessage(payload, session.getId());
// TODO: validate size limits
// http://stomp.github.io/stomp-specification-1.2.html#Size_Limits
@ -144,18 +142,17 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @@ -144,18 +142,17 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
}
}
@SuppressWarnings("unchecked")
protected void handleConnect(final WebSocketSession session, Message<byte[]> message) throws IOException {
StompHeaders connectStompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaders connectedStompHeaders = StompHeaders.create(StompCommand.CONNECTED);
StompHeaders connectHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaders connectedHeaders = StompHeaders.create(StompCommand.CONNECTED);
Set<String> acceptVersions = connectStompHeaders.getAcceptVersion();
Set<String> acceptVersions = connectHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) {
connectedStompHeaders.setAcceptVersion("1.2");
connectedHeaders.setAcceptVersion("1.2");
}
else if (acceptVersions.contains("1.1")) {
connectedStompHeaders.setAcceptVersion("1.1");
connectedHeaders.setAcceptVersion("1.1");
}
else if (acceptVersions.isEmpty()) {
// 1.0
@ -163,11 +160,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @@ -163,11 +160,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
else {
throw new StompConversionException("Unsupported version '" + acceptVersions + "'");
}
connectedStompHeaders.setHeartbeat(0,0); // TODO
connectedHeaders.setHeartbeat(0,0); // TODO
// TODO: security
Message<byte[]> connectedMessage = messageFactory.createMessage(new byte[0], connectedStompHeaders.toMessageHeaders());
Message<byte[]> connectedMessage = MessageBuilder.fromPayloadAndHeaders(EMPTY_PAYLOAD,
connectedHeaders.toMessageHeaders()).build();
byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
@ -187,14 +185,14 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @@ -187,14 +185,14 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
protected void handleDisconnect(Message<byte[]> stompMessage) {
}
@SuppressWarnings("unchecked")
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR);
stompHeaders.setMessage(error.getMessage());
StompHeaders headers = StompHeaders.create(StompCommand.ERROR);
headers.setMessage(error.getMessage());
Message<byte[]> errorMessage = messageFactory.createMessage(new byte[0], stompHeaders.toMessageHeaders());
byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage);
Message<byte[]> message = MessageBuilder.fromPayloadAndHeaders(EMPTY_PAYLOAD,
headers.toMessageHeaders()).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message);
try {
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
@ -214,19 +212,18 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @@ -214,19 +212,18 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
private final class ClientMessageConsumer implements MessageHandler<Message<?>> {
@SuppressWarnings("unchecked")
@Override
public void handleMessage(Message<?> message) {
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
stompHeaders.setStompCommandIfNotSet(StompCommand.MESSAGE);
StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders());
headers.setStompCommandIfNotSet(StompCommand.MESSAGE);
if (StompCommand.CONNECTED.equals(stompHeaders.getStompCommand())) {
if (StompCommand.CONNECTED.equals(headers.getStompCommand())) {
// Ignore for now since we already sent it
return;
}
String sessionId = stompHeaders.getSessionId();
String sessionId = headers.getSessionId();
if (sessionId == null) {
logger.error("No \"sessionId\" header in message: " + message);
}
@ -237,7 +234,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @@ -237,7 +234,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
byte[] payload;
try {
MediaType contentType = stompHeaders.getContentType();
MediaType contentType = headers.getContentType();
payload = payloadConverter.convertToPayload(message.getPayload(), contentType);
}
catch (Throwable t) {
@ -246,8 +243,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @@ -246,8 +243,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
}
try {
Map<String, Object> messageHeaders = stompHeaders.toMessageHeaders();
Message<byte[]> byteMessage = messageFactory.createMessage(payload, messageHeaders);
Message<byte[]> byteMessage = MessageBuilder.fromPayloadAndHeaders(payload,
headers.toMessageHeaders()).build();
byte[] bytes = getStompMessageConverter().fromMessage(byteMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
@ -255,7 +252,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @@ -255,7 +252,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
sendErrorMessage(session, t);
}
finally {
if (StompCommand.ERROR.equals(stompHeaders.getStompCommand())) {
if (StompCommand.ERROR.equals(headers.getStompCommand())) {
try {
session.close(CloseStatus.PROTOCOL_ERROR);
}

16
spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java

@ -19,9 +19,7 @@ import java.util.Collections; @@ -19,9 +19,7 @@ import java.util.Collections;
import org.junit.Before;
import org.junit.Test;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.MessageHeaders;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.stomp.StompCommand;
@ -37,22 +35,19 @@ public class StompMessageConverterTests { @@ -37,22 +35,19 @@ public class StompMessageConverterTests {
private StompMessageConverter converter;
private MessageFactory messageFactory = new GenericMessageFactory();
@Before
public void setup() {
this.converter = new StompMessageConverter();
}
@SuppressWarnings("unchecked")
@Test
public void connectFrame() throws Exception {
String accept = "accept-version:1.1\n";
String host = "host:github.org\n";
String frame = "\n\n\nCONNECT\n" + accept + host + "\n";
Message<byte[]> message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123", messageFactory);
Message<byte[]> message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
assertEquals(0, message.getPayload().length);
@ -76,14 +71,13 @@ public class StompMessageConverterTests { @@ -76,14 +71,13 @@ public class StompMessageConverterTests {
assertTrue(convertedBack.contains(host));
}
@SuppressWarnings("unchecked")
@Test
public void connectWithEscapes() throws Exception {
String accept = "accept-version:1.1\n";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String frame = "CONNECT\n" + accept + host + "\n";
Message<byte[]> message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123", messageFactory);
Message<byte[]> message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
assertEquals(0, message.getPayload().length);
@ -99,14 +93,13 @@ public class StompMessageConverterTests { @@ -99,14 +93,13 @@ public class StompMessageConverterTests {
assertTrue(convertedBack.contains(host));
}
@SuppressWarnings("unchecked")
@Test
public void connectCR12() throws Exception {
String accept = "accept-version:1.2\n";
String host = "host:github.org\n";
String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
Message<byte[]> message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123", messageFactory);
Message<byte[]> message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
assertEquals(0, message.getPayload().length);
@ -122,14 +115,13 @@ public class StompMessageConverterTests { @@ -122,14 +115,13 @@ public class StompMessageConverterTests {
assertTrue(convertedBack.contains(host));
}
@SuppressWarnings("unchecked")
@Test
public void connectWithEscapesAndCR12() throws Exception {
String accept = "accept-version:1.1\n";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
Message<byte[]> message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123", messageFactory);
Message<byte[]> message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
assertEquals(0, message.getPayload().length);

Loading…
Cancel
Save