diff --git a/build.gradle b/build.gradle index 120979a964e..f35fe502dda 100644 --- a/build.gradle +++ b/build.gradle @@ -598,6 +598,8 @@ project("spring-websocket") { testCompile("org.apache.tomcat.embed:tomcat-embed-core:8.0.0-RC5") testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") testCompile("log4j:log4j:1.2.17") + testCompile("org.projectreactor:reactor-core:1.0.0.RELEASE") + testCompile("org.projectreactor:reactor-tcp:1.0.0.RELEASE") } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractBrokerRegistration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractBrokerRegistration.java index 573fbd2e7c0..ff8d5427025 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractBrokerRegistration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractBrokerRegistration.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.List; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.handler.AbstractBrokerMessageHandler; import org.springframework.util.Assert; @@ -33,19 +34,31 @@ import org.springframework.util.Assert; */ public abstract class AbstractBrokerRegistration { + private final SubscribableChannel clientInboundChannel; + private final MessageChannel clientOutboundChannel; private final List destinationPrefixes; - public AbstractBrokerRegistration(MessageChannel clientOutboundChannel, String[] destinationPrefixes) { + public AbstractBrokerRegistration(SubscribableChannel clientInboundChannel, + MessageChannel clientOutboundChannel, String[] destinationPrefixes) { + + Assert.notNull(clientOutboundChannel, "'clientInboundChannel' must not be null"); Assert.notNull(clientOutboundChannel, "'clientOutboundChannel' must not be null"); + + this.clientInboundChannel = clientInboundChannel; this.clientOutboundChannel = clientOutboundChannel; + this.destinationPrefixes = (destinationPrefixes != null) ? Arrays.asList(destinationPrefixes) : Collections.emptyList(); } + protected SubscribableChannel getClientInboundChannel() { + return this.clientInboundChannel; + } + protected MessageChannel getClientOutboundChannel() { return this.clientOutboundChannel; } @@ -54,6 +67,6 @@ public abstract class AbstractBrokerRegistration { return this.destinationPrefixes; } - protected abstract AbstractBrokerMessageHandler getMessageHandler(); + protected abstract AbstractBrokerMessageHandler getMessageHandler(SubscribableChannel brokerChannel); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java index 748f0c9c51b..65605dcb2b5 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java @@ -47,7 +47,7 @@ import java.util.List; * into any application component to send messages. *

* Sub-classes are responsible for the part of the configuration that feed messages - * to and from the client inbound/outbound channels (e.g. STOMP over WebSokcet). + * to and from the client inbound/outbound channels (e.g. STOMP over WebSocket). * * @author Rossen Stoyanchev * @since 4.0 @@ -86,7 +86,7 @@ public abstract class AbstractMessageBrokerConfiguration { public ThreadPoolTaskExecutor clientInboundChannelExecutor() { TaskExecutorRegistration r = getClientInboundChannelRegistration().getTaskExecutorRegistration(); ThreadPoolTaskExecutor executor = (r != null) ? r.getTaskExecutor() : new ThreadPoolTaskExecutor(); - executor.setThreadNamePrefix("ClientInboundChannel-"); + executor.setThreadNamePrefix("clientInboundChannel-"); return executor; } @@ -121,7 +121,7 @@ public abstract class AbstractMessageBrokerConfiguration { public ThreadPoolTaskExecutor clientOutboundChannelExecutor() { TaskExecutorRegistration r = getClientOutboundChannelRegistration().getTaskExecutorRegistration(); ThreadPoolTaskExecutor executor = (r != null) ? r.getTaskExecutor() : new ThreadPoolTaskExecutor(); - executor.setThreadNamePrefix("ClientOutboundChannel-"); + executor.setThreadNamePrefix("clientOutboundChannel-"); return executor; } @@ -160,7 +160,7 @@ public abstract class AbstractMessageBrokerConfiguration { public ThreadPoolTaskExecutor brokerChannelExecutor() { TaskExecutorRegistration r = getBrokerRegistry().getBrokerChannelRegistration().getTaskExecutorRegistration(); ThreadPoolTaskExecutor executor = (r != null) ? r.getTaskExecutor() : new ThreadPoolTaskExecutor(); - executor.setThreadNamePrefix("BrokerChannel-"); + executor.setThreadNamePrefix("brokerChannel-"); return executor; } @@ -170,7 +170,7 @@ public abstract class AbstractMessageBrokerConfiguration { */ protected final MessageBrokerRegistry getBrokerRegistry() { if (this.brokerRegistry == null) { - MessageBrokerRegistry registry = new MessageBrokerRegistry(clientOutboundChannel()); + MessageBrokerRegistry registry = new MessageBrokerRegistry(clientInboundChannel(), clientOutboundChannel()); configureMessageBroker(registry); this.brokerRegistry = registry; } @@ -187,45 +187,30 @@ public abstract class AbstractMessageBrokerConfiguration { @Bean public SimpAnnotationMethodMessageHandler simpAnnotationMethodMessageHandler() { - SimpAnnotationMethodMessageHandler handler = - new SimpAnnotationMethodMessageHandler(brokerMessagingTemplate(), clientOutboundChannel()); + SimpAnnotationMethodMessageHandler handler = new SimpAnnotationMethodMessageHandler( + clientInboundChannel(), clientOutboundChannel(), brokerMessagingTemplate()); handler.setDestinationPrefixes(getBrokerRegistry().getApplicationDestinationPrefixes()); handler.setMessageConverter(brokerMessageConverter()); - clientInboundChannel().subscribe(handler); return handler; } @Bean public AbstractBrokerMessageHandler simpleBrokerMessageHandler() { - SimpleBrokerMessageHandler handler = getBrokerRegistry().getSimpleBroker(); - if (handler != null) { - clientInboundChannel().subscribe(handler); - brokerChannel().subscribe(handler); - return handler; - } - return noopBroker; + SimpleBrokerMessageHandler handler = getBrokerRegistry().getSimpleBroker(brokerChannel()); + return (handler != null) ? handler : noopBroker; } @Bean public AbstractBrokerMessageHandler stompBrokerRelayMessageHandler() { - AbstractBrokerMessageHandler handler = getBrokerRegistry().getStompBrokerRelay(); - if (handler != null) { - clientInboundChannel().subscribe(handler); - brokerChannel().subscribe(handler); - return handler; - } - return noopBroker; + AbstractBrokerMessageHandler handler = getBrokerRegistry().getStompBrokerRelay(brokerChannel()); + return (handler != null) ? handler : noopBroker; } @Bean public UserDestinationMessageHandler userDestinationMessageHandler() { - UserDestinationMessageHandler handler = new UserDestinationMessageHandler( - brokerMessagingTemplate(), userDestinationResolver()); - - clientInboundChannel().subscribe(handler); - brokerChannel().subscribe(handler); + clientInboundChannel(), clientOutboundChannel(), brokerChannel(), userDestinationResolver()); return handler; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java index 1525dbf0f75..1a01db94d57 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java @@ -20,6 +20,7 @@ import java.util.Arrays; import java.util.Collection; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.handler.SimpleBrokerMessageHandler; import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; import org.springframework.util.Assert; @@ -32,11 +33,13 @@ import org.springframework.util.Assert; */ public class MessageBrokerRegistry { + private final SubscribableChannel clientInboundChannel; + private final MessageChannel clientOutboundChannel; - private SimpleBrokerRegistration simpleBroker; + private SimpleBrokerRegistration simpleBrokerRegistration; - private StompBrokerRelayRegistration stompRelay; + private StompBrokerRelayRegistration brokerRelayRegistration; private String[] applicationDestinationPrefixes; @@ -45,8 +48,10 @@ public class MessageBrokerRegistry { private ChannelRegistration brokerChannelRegistration = new ChannelRegistration(); - public MessageBrokerRegistry(MessageChannel clientOutboundChannel) { + public MessageBrokerRegistry(SubscribableChannel clientInboundChannel, MessageChannel clientOutboundChannel) { + Assert.notNull(clientInboundChannel); Assert.notNull(clientOutboundChannel); + this.clientInboundChannel = clientInboundChannel; this.clientOutboundChannel = clientOutboundChannel; } @@ -55,8 +60,9 @@ public class MessageBrokerRegistry { * destinations targeting the broker (e.g. destinations prefixed with "/topic"). */ public SimpleBrokerRegistration enableSimpleBroker(String... destinationPrefixes) { - this.simpleBroker = new SimpleBrokerRegistration(this.clientOutboundChannel, destinationPrefixes); - return this.simpleBroker; + this.simpleBrokerRegistration = new SimpleBrokerRegistration( + this.clientInboundChannel, this.clientOutboundChannel, destinationPrefixes); + return this.simpleBrokerRegistration; } /** @@ -65,8 +71,9 @@ public class MessageBrokerRegistry { * destinations. */ public StompBrokerRelayRegistration enableStompBrokerRelay(String... destinationPrefixes) { - this.stompRelay = new StompBrokerRelayRegistration(this.clientOutboundChannel, destinationPrefixes); - return this.stompRelay; + this.brokerRelayRegistration = new StompBrokerRelayRegistration( + this.clientInboundChannel, this.clientOutboundChannel, destinationPrefixes); + return this.brokerRelayRegistration; } /** @@ -113,19 +120,21 @@ public class MessageBrokerRegistry { } - protected SimpleBrokerMessageHandler getSimpleBroker() { - initSimpleBrokerIfNecessary(); - return (this.simpleBroker != null) ? this.simpleBroker.getMessageHandler() : null; - } - - protected void initSimpleBrokerIfNecessary() { - if ((this.simpleBroker == null) && (this.stompRelay == null)) { - this.simpleBroker = new SimpleBrokerRegistration(this.clientOutboundChannel, null); + protected SimpleBrokerMessageHandler getSimpleBroker(SubscribableChannel brokerChannel) { + if ((this.simpleBrokerRegistration == null) && (this.brokerRelayRegistration == null)) { + enableSimpleBroker(); } + if (this.simpleBrokerRegistration != null) { + return this.simpleBrokerRegistration.getMessageHandler(brokerChannel); + } + return null; } - protected StompBrokerRelayMessageHandler getStompBrokerRelay() { - return (this.stompRelay != null) ? this.stompRelay.getMessageHandler() : null; + protected StompBrokerRelayMessageHandler getStompBrokerRelay(SubscribableChannel brokerChannel) { + if (this.brokerRelayRegistration != null) { + return this.brokerRelayRegistration.getMessageHandler(brokerChannel); + } + return null; } protected Collection getApplicationDestinationPrefixes() { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/SimpleBrokerRegistration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/SimpleBrokerRegistration.java index a9db100c91b..df1ed72ec4b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/SimpleBrokerRegistration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/SimpleBrokerRegistration.java @@ -17,6 +17,7 @@ package org.springframework.messaging.simp.config; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.handler.SimpleBrokerMessageHandler; /** @@ -28,14 +29,17 @@ import org.springframework.messaging.simp.handler.SimpleBrokerMessageHandler; public class SimpleBrokerRegistration extends AbstractBrokerRegistration { - public SimpleBrokerRegistration(MessageChannel clientOutboundChannel, String[] destinationPrefixes) { - super(clientOutboundChannel, destinationPrefixes); + public SimpleBrokerRegistration(SubscribableChannel clientInboundChannel, + MessageChannel clientOutboundChannel, String[] destinationPrefixes) { + + super(clientInboundChannel, clientOutboundChannel, destinationPrefixes); } @Override - protected SimpleBrokerMessageHandler getMessageHandler() { - return new SimpleBrokerMessageHandler(getClientOutboundChannel(), getDestinationPrefixes()); + protected SimpleBrokerMessageHandler getMessageHandler(SubscribableChannel brokerChannel) { + return new SimpleBrokerMessageHandler(getClientInboundChannel(), + getClientOutboundChannel(), brokerChannel, getDestinationPrefixes()); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistration.java index cb1e1abf251..f778964bf3d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistration.java @@ -17,6 +17,7 @@ package org.springframework.messaging.simp.config; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; import org.springframework.util.Assert; @@ -43,8 +44,10 @@ public class StompBrokerRelayRegistration extends AbstractBrokerRegistration { private boolean autoStartup = true; - public StompBrokerRelayRegistration(MessageChannel clientOutboundChannel, String[] destinationPrefixes) { - super(clientOutboundChannel, destinationPrefixes); + public StompBrokerRelayRegistration(SubscribableChannel clientInboundChannel, + MessageChannel clientOutboundChannel, String[] destinationPrefixes) { + + super(clientInboundChannel, clientOutboundChannel, destinationPrefixes); } @@ -119,10 +122,10 @@ public class StompBrokerRelayRegistration extends AbstractBrokerRegistration { } - protected StompBrokerRelayMessageHandler getMessageHandler() { + protected StompBrokerRelayMessageHandler getMessageHandler(SubscribableChannel brokerChannel) { - StompBrokerRelayMessageHandler handler = - new StompBrokerRelayMessageHandler(getClientOutboundChannel(), getDestinationPrefixes()); + StompBrokerRelayMessageHandler handler = new StompBrokerRelayMessageHandler(getClientInboundChannel(), + getClientOutboundChannel(), brokerChannel, getDestinationPrefixes()); handler.setRelayHost(this.relayHost); handler.setRelayPort(this.relayPort); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodMessageHandler.java index 15172cea0ab..d287710a76c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodMessageHandler.java @@ -27,11 +27,13 @@ import java.util.Set; import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.SmartLifecycle; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.convert.ConversionService; import org.springframework.format.support.DefaultFormattingConversionService; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.core.AbstractMessageSendingTemplate; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.support.AnnotationExceptionHandlerMethodResolver; @@ -54,6 +56,7 @@ import org.springframework.messaging.simp.annotation.support.PrincipalMethodArgu import org.springframework.messaging.simp.annotation.support.SendToMethodReturnValueHandler; import org.springframework.messaging.simp.annotation.support.SubscriptionMethodReturnValueHandler; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.channel.AbstractSubscribableChannel; import org.springframework.messaging.support.converter.ByteArrayMessageConverter; import org.springframework.messaging.support.converter.CompositeMessageConverter; import org.springframework.messaging.support.converter.MessageConverter; @@ -74,30 +77,44 @@ import org.springframework.util.PathMatcher; * @author Brian Clozel * @since 4.0 */ -public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHandler { +public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHandler + implements SmartLifecycle { - private final SimpMessageSendingOperations brokerTemplate; + private final SubscribableChannel clientInboundChannel; private final SimpMessageSendingOperations clientMessagingTemplate; + private final SimpMessageSendingOperations brokerTemplate; + private MessageConverter messageConverter; private ConversionService conversionService = new DefaultFormattingConversionService(); private PathMatcher pathMatcher = new AntPathMatcher(); + private Object lifecycleMonitor = new Object(); + + private volatile boolean running = false; + /** - * @param brokerTemplate a messaging template to send application messages to the broker + * Create an instance of SimpAnnotationMethodMessageHandler with the given + * message channels and broker messaging template. + * + * @param clientInboundChannel the channel for receiving messages from clients (e.g. WebSocket clients) * @param clientOutboundChannel the channel for messages to clients (e.g. WebSocket clients) + * @param brokerTemplate a messaging template to send application messages to the broker */ - public SimpAnnotationMethodMessageHandler(SimpMessageSendingOperations brokerTemplate, - MessageChannel clientOutboundChannel) { + public SimpAnnotationMethodMessageHandler(SubscribableChannel clientInboundChannel, + MessageChannel clientOutboundChannel, SimpMessageSendingOperations brokerTemplate) { - Assert.notNull(brokerTemplate, "BrokerTemplate must not be null"); - Assert.notNull(clientOutboundChannel, "ClientOutboundChannel must not be null"); - this.brokerTemplate = brokerTemplate; + Assert.notNull(clientInboundChannel, "clientInboundChannel must not be null"); + Assert.notNull(clientOutboundChannel, "clientOutboundChannel must not be null"); + Assert.notNull(brokerTemplate, "brokerTemplate must not be null"); + + this.clientInboundChannel = clientInboundChannel; this.clientMessagingTemplate = new SimpMessagingTemplate(clientOutboundChannel); + this.brokerTemplate = brokerTemplate; Collection converters = new ArrayList(); converters.add(new StringMessageConverter()); @@ -159,6 +176,46 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan return this.pathMatcher; } + @Override + public boolean isAutoStartup() { + return true; + } + + @Override + public int getPhase() { + return Integer.MAX_VALUE; + } + + @Override + public final boolean isRunning() { + synchronized (this.lifecycleMonitor) { + return this.running; + } + } + + @Override + public final void start() { + synchronized (this.lifecycleMonitor) { + this.clientInboundChannel.subscribe(this); + this.running = true; + } + } + + @Override + public final void stop() { + synchronized (this.lifecycleMonitor) { + this.running = false; + this.clientInboundChannel.unsubscribe(this); + } + } + + @Override + public final void stop(Runnable callback) { + synchronized (this.lifecycleMonitor) { + stop(); + callback.run(); + } + } protected List initArgumentResolvers() { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java index 531353a6433..902f59fbce5 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java @@ -20,6 +20,7 @@ import java.util.Collection; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; @@ -34,23 +35,49 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { private static final byte[] EMPTY_PAYLOAD = new byte[0]; - private final MessageChannel messageChannel; + private final SubscribableChannel clientInboundChannel; + + private final MessageChannel clientOutboundChannel; + + private final SubscribableChannel brokerChannel; private SubscriptionRegistry subscriptionRegistry = new DefaultSubscriptionRegistry(); /** - * @param messageChannel the channel to broadcast messages to + * Create a SimpleBrokerMessageHandler instance with the given message channels + * and destination prefixes. + * + * @param clientInboundChannel the channel for receiving messages from clients (e.g. WebSocket clients) + * @param clientOutboundChannel the channel for sending messages to clients (e.g. WebSocket clients) + * @param brokerChannel the channel for the application to send messages to the broker */ - public SimpleBrokerMessageHandler(MessageChannel messageChannel, Collection destinationPrefixes) { + public SimpleBrokerMessageHandler(SubscribableChannel clientInboundChannel, MessageChannel clientOutboundChannel, + SubscribableChannel brokerChannel, Collection destinationPrefixes) { + super(destinationPrefixes); - Assert.notNull(messageChannel, "MessageChannel must not be null"); - this.messageChannel = messageChannel; + + Assert.notNull(clientInboundChannel, "'clientInboundChannel' must not be null"); + Assert.notNull(clientOutboundChannel, "'clientOutboundChannel' must not be null"); + Assert.notNull(brokerChannel, "'brokerChannel' must not be null"); + + + this.clientInboundChannel = clientInboundChannel; + this.clientOutboundChannel = clientOutboundChannel; + this.brokerChannel = brokerChannel; } - public MessageChannel getMessageChannel() { - return this.messageChannel; + public SubscribableChannel getClientInboundChannel() { + return this.clientInboundChannel; + } + + public MessageChannel getClientOutboundChannel() { + return this.clientOutboundChannel; + } + + public SubscribableChannel getBrokerChannel() { + return this.brokerChannel; } public void setSubscriptionRegistry(SubscriptionRegistry subscriptionRegistry) { @@ -66,11 +93,15 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { @Override public void startInternal() { publishBrokerAvailableEvent(); + this.clientInboundChannel.subscribe(this); + this.brokerChannel.subscribe(this); } @Override public void stopInternal() { publishBrokerUnavailableEvent(); + this.clientInboundChannel.unsubscribe(this); + this.brokerChannel.unsubscribe(this); } @Override @@ -106,7 +137,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { replyHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message); Message connectAck = MessageBuilder.withPayload(EMPTY_PAYLOAD).setHeaders(replyHeaders).build(); - this.messageChannel.send(connectAck); + this.clientOutboundChannel.send(connectAck); } } @@ -124,7 +155,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { Object payload = message.getPayload(); Message clientMessage = MessageBuilder.withPayload(payload).setHeaders(headers).build(); try { - this.messageChannel.send(clientMessage); + this.clientOutboundChannel.send(clientMessage); } catch (Throwable ex) { logger.error("Failed to send message to destination=" + destination + diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandler.java index f194217f95e..01b177ae440 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandler.java @@ -20,10 +20,10 @@ import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.MessagingException; +import org.springframework.context.SmartLifecycle; +import org.springframework.messaging.*; import org.springframework.messaging.core.MessageSendingOperations; +import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -37,28 +37,46 @@ import org.springframework.util.CollectionUtils; * @author Rossen Stoyanchev * @since 4.0 */ -public class UserDestinationMessageHandler implements MessageHandler { +public class UserDestinationMessageHandler implements MessageHandler, SmartLifecycle { private static final Log logger = LogFactory.getLog(UserDestinationMessageHandler.class); - private final MessageSendingOperations messagingTemplate; + private final SubscribableChannel clientInboundChannel; + + private final MessageChannel clientOutboundChannel; + + private final SubscribableChannel brokerChannel; + + private final MessageSendingOperations brokerMessagingTemplate; private final UserDestinationResolver userDestinationResolver; + private Object lifecycleMonitor = new Object(); + + private volatile boolean running = false; + /** * Create an instance of the handler with the given messaging template and a * user destination resolver. - * @param messagingTemplate a messaging template to use for sending messages - * with translated user destinations + * @param clientInChannel the channel for receiving messages from clients (e.g. WebSocket clients) + * @param clientOutChannel the channel for sending messages to clients (e.g. WebSocket clients) + * @param brokerChannel the channel for sending messages with translated user destinations * @param userDestinationResolver the resolver to use to find queue suffixes for a user */ - public UserDestinationMessageHandler(MessageSendingOperations messagingTemplate, - UserDestinationResolver userDestinationResolver) { - Assert.notNull(messagingTemplate, "MessagingTemplate must not be null"); + public UserDestinationMessageHandler(SubscribableChannel clientInChannel, MessageChannel clientOutChannel, + SubscribableChannel brokerChannel, UserDestinationResolver userDestinationResolver) { + + Assert.notNull(clientInChannel, "'clientInChannel' must not be null"); + Assert.notNull(clientOutChannel, "'clientOutChannel' must not be null"); + Assert.notNull(brokerChannel, "'brokerChannel' must not be null"); Assert.notNull(userDestinationResolver, "DestinationResolver must not be null"); - this.messagingTemplate = messagingTemplate; + + this.clientInboundChannel = clientInChannel; + this.clientOutboundChannel = clientOutChannel; + this.brokerChannel = brokerChannel; + this.brokerMessagingTemplate = new SimpMessagingTemplate(brokerChannel); this.userDestinationResolver = userDestinationResolver; } @@ -73,10 +91,52 @@ public class UserDestinationMessageHandler implements MessageHandler { * Return the configured messaging template for sending messages with * translated destinations. */ - public MessageSendingOperations getMessagingTemplate() { - return this.messagingTemplate; + public MessageSendingOperations getBrokerMessagingTemplate() { + return this.brokerMessagingTemplate; + } + + @Override + public boolean isAutoStartup() { + return true; + } + + @Override + public int getPhase() { + return Integer.MAX_VALUE; + } + + @Override + public final boolean isRunning() { + synchronized (this.lifecycleMonitor) { + return this.running; + } + } + + @Override + public final void start() { + synchronized (this.lifecycleMonitor) { + this.clientInboundChannel.subscribe(this); + this.brokerChannel.subscribe(this); + this.running = true; + } } + @Override + public final void stop() { + synchronized (this.lifecycleMonitor) { + this.running = false; + this.clientInboundChannel.unsubscribe(this); + this.brokerChannel.unsubscribe(this); + } + } + + @Override + public final void stop(Runnable callback) { + synchronized (this.lifecycleMonitor) { + stop(); + callback.run(); + } + } @Override public void handleMessage(Message message) throws MessagingException { @@ -90,7 +150,7 @@ public class UserDestinationMessageHandler implements MessageHandler { if (logger.isDebugEnabled()) { logger.debug("Sending message to resolved destination=" + targetDestination); } - this.messagingTemplate.send(targetDestination, message); + this.brokerMessagingTemplate.send(targetDestination, message); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 9941cff977d..058e0b75dc1 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -21,10 +21,7 @@ import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.MessageDeliveryException; -import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.*; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.handler.AbstractBrokerMessageHandler; @@ -77,7 +74,11 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } - private final MessageChannel messageChannel; + private final SubscribableChannel clientInboundChannel; + + private final MessageChannel clientOutboundChannel; + + private final SubscribableChannel brokerChannel; private String relayHost = "127.0.0.1"; @@ -100,14 +101,28 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler /** - * @param messageChannel the channel to send messages from the STOMP broker to + * Create a StompBrokerRelayMessageHandler instance with the given message channels + * and destination prefixes. + * + * @param clientInChannel the channel for receiving messages from clients (e.g. WebSocket clients) + * @param clientOutChannel the channel for sending messages to clients (e.g. WebSocket clients) + * @param brokerChannel the channel for the application to send messages to the broker * @param destinationPrefixes the broker supported destination prefixes; destinations * that do not match the given prefix are ignored. */ - public StompBrokerRelayMessageHandler(MessageChannel messageChannel, Collection destinationPrefixes) { + public StompBrokerRelayMessageHandler(SubscribableChannel clientInChannel, MessageChannel clientOutChannel, + SubscribableChannel brokerChannel, Collection destinationPrefixes) { + super(destinationPrefixes); - Assert.notNull(messageChannel, "MessageChannel must not be null"); - this.messageChannel = messageChannel; + + Assert.notNull(clientInChannel, "'clientInChannel' must not be null"); + Assert.notNull(clientOutChannel, "'clientOutChannel' must not be null"); + Assert.notNull(brokerChannel, "'brokerChannel' must not be null"); + + + this.clientInboundChannel = clientInChannel; + this.clientOutboundChannel = clientOutChannel; + this.brokerChannel = brokerChannel; } @@ -242,6 +257,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @Override protected void startInternal() { + this.clientInboundChannel.subscribe(this); + this.brokerChannel.subscribe(this); + if (this.tcpClient == null) { this.tcpClient = new ReactorNettyTcpClient(this.relayHost, this.relayPort, new StompCodec()); } @@ -265,6 +283,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @Override protected void stopInternal() { + + this.clientInboundChannel.unsubscribe(this); + this.brokerChannel.unsubscribe(this); + for (StompConnectionHandler handler : this.connectionHandlers.values()) { try { handler.resetTcpConnection(); @@ -416,7 +438,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler protected void sendMessageToClient(Message message) { if (this.isRemoteClientSession) { - StompBrokerRelayMessageHandler.this.messageChannel.send(message); + StompBrokerRelayMessageHandler.this.clientOutboundChannel.send(message); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java index 057543133c2..8e461682559 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java @@ -43,7 +43,7 @@ public abstract class AbstractSubscribableChannel extends AbstractMessageChannel /** * Whether the given {@link MessageHandler} is already subscribed. */ - protected abstract boolean hasSubscription(MessageHandler handler); + public abstract boolean hasSubscription(MessageHandler handler); /** * Subscribe the given {@link MessageHandler}. diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ExecutorSubscribableChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ExecutorSubscribableChannel.java index 440f64ba55f..ecf71c9c4f9 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ExecutorSubscribableChannel.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ExecutorSubscribableChannel.java @@ -62,7 +62,7 @@ public class ExecutorSubscribableChannel extends AbstractSubscribableChannel { } @Override - protected boolean hasSubscription(MessageHandler handler) { + public boolean hasSubscription(MessageHandler handler) { return this.handlers.contains(handler); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/StubMessageChannel.java b/spring-messaging/src/test/java/org/springframework/messaging/StubMessageChannel.java index 7440cd2cbc4..ac40836c449 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/StubMessageChannel.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/StubMessageChannel.java @@ -24,10 +24,12 @@ import java.util.List; * * @author Rossen Stoyanchev */ -public class StubMessageChannel implements MessageChannel { +public class StubMessageChannel implements SubscribableChannel { private final List> messages = new ArrayList<>(); + private final List handlers = new ArrayList<>(); + public List> getMessages() { return this.messages; @@ -47,4 +49,15 @@ public class StubMessageChannel implements MessageChannel { return true; } + @Override + public boolean subscribe(MessageHandler handler) { + this.handlers.add(handler); + return true; + } + + @Override + public boolean unsubscribe(MessageHandler handler) { + this.handlers.remove(handler); + return true; + } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java index 5a036448a0f..4b4d6d5db79 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java @@ -46,7 +46,6 @@ import org.springframework.stereotype.Controller; import org.springframework.util.MimeTypeUtils; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import static org.junit.Assert.*; @@ -319,6 +318,7 @@ public class MessageBrokerConfigurationTests { } @Override + @Bean public AbstractSubscribableChannel brokerChannel() { return new TestChannel(); } @@ -334,7 +334,7 @@ public class MessageBrokerConfigurationTests { @Override public void configureMessageBroker(MessageBrokerRegistry registry) { - registry.enableStompBrokerRelay("/topic", "/queue").setAutoStartup(false); + registry.enableStompBrokerRelay("/topic", "/queue").setAutoStartup(true); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodMessageHandlerTests.java index 29609dd7447..6ae0d7be0ce 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodMessageHandlerTests.java @@ -25,6 +25,7 @@ import org.mockito.Mockito; import org.springframework.context.support.StaticApplicationContext; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.handler.annotation.Header; import org.springframework.messaging.handler.annotation.Headers; import org.springframework.messaging.handler.annotation.MessageMapping; @@ -53,9 +54,9 @@ public class SimpAnnotationMethodMessageHandlerTests { @Before public void setup() { - MessageChannel channel = Mockito.mock(MessageChannel.class); + SubscribableChannel channel = Mockito.mock(SubscribableChannel.class); SimpMessageSendingOperations brokerTemplate = new SimpMessagingTemplate(channel); - this.messageHandler = new TestSimpAnnotationMethodMessageHandler(brokerTemplate, channel); + this.messageHandler = new TestSimpAnnotationMethodMessageHandler(brokerTemplate, channel, channel); this.messageHandler.setApplicationContext(new StaticApplicationContext()); this.messageHandler.afterPropertiesSet(); @@ -145,9 +146,9 @@ public class SimpAnnotationMethodMessageHandlerTests { private static class TestSimpAnnotationMethodMessageHandler extends SimpAnnotationMethodMessageHandler { public TestSimpAnnotationMethodMessageHandler(SimpMessageSendingOperations brokerTemplate, - MessageChannel clientOutboundChannel) { + SubscribableChannel clientInboundChannel, MessageChannel clientOutboundChannel) { - super(brokerTemplate, clientOutboundChannel); + super(clientInboundChannel, clientOutboundChannel, brokerTemplate); } public void registerHandler(Object handler) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandlerTests.java index a026515de06..f86c6f54e43 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandlerTests.java @@ -26,6 +26,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; @@ -43,7 +44,13 @@ public class SimpleBrokerMessageHandlerTests { private SimpleBrokerMessageHandler messageHandler; @Mock - private MessageChannel clientChannel; + private SubscribableChannel clientInboundChannel; + + @Mock + private MessageChannel clientOutboundChannel; + + @Mock + private SubscribableChannel brokerChannel; @Captor ArgumentCaptor> messageCaptor; @@ -52,7 +59,8 @@ public class SimpleBrokerMessageHandlerTests { @Before public void setup() { MockitoAnnotations.initMocks(this); - this.messageHandler = new SimpleBrokerMessageHandler(this.clientChannel, Collections.emptyList()); + this.messageHandler = new SimpleBrokerMessageHandler(this.clientInboundChannel, + this.clientOutboundChannel, this.brokerChannel, Collections.emptyList()); } @@ -72,7 +80,7 @@ public class SimpleBrokerMessageHandlerTests { this.messageHandler.handleMessage(createMessage("/foo", "message1")); this.messageHandler.handleMessage(createMessage("/bar", "message2")); - verify(this.clientChannel, times(6)).send(this.messageCaptor.capture()); + verify(this.clientOutboundChannel, times(6)).send(this.messageCaptor.capture()); assertCapturedMessage("sess1", "sub1", "/foo"); assertCapturedMessage("sess1", "sub2", "/foo"); assertCapturedMessage("sess2", "sub1", "/foo"); @@ -105,7 +113,7 @@ public class SimpleBrokerMessageHandlerTests { this.messageHandler.handleMessage(createMessage("/foo", "message1")); this.messageHandler.handleMessage(createMessage("/bar", "message2")); - verify(this.clientChannel, times(3)).send(this.messageCaptor.capture()); + verify(this.clientOutboundChannel, times(3)).send(this.messageCaptor.capture()); assertCapturedMessage(sess2, "sub1", "/foo"); assertCapturedMessage(sess2, "sub2", "/foo"); assertCapturedMessage(sess2, "sub3", "/bar"); @@ -121,7 +129,7 @@ public class SimpleBrokerMessageHandlerTests { Message connectMessage = createConnectMessage(sess1); this.messageHandler.handleMessage(connectMessage); - verify(this.clientChannel, times(1)).send(this.messageCaptor.capture()); + verify(this.clientOutboundChannel, times(1)).send(this.messageCaptor.capture()); Message connectAckMessage = this.messageCaptor.getValue(); SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.wrap(connectAckMessage); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandlerTests.java index 7ef9a3ee5e7..813743d403f 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandlerTests.java @@ -16,21 +16,22 @@ package org.springframework.messaging.simp.handler; -import org.apache.activemq.transport.stomp.Stomp; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; +import org.mockito.Mock; import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; import org.springframework.messaging.Message; -import org.springframework.messaging.core.MessageSendingOperations; +import org.springframework.messaging.StubMessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.TestPrincipal; -import org.springframework.messaging.simp.stomp.StompCommand; -import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.support.MessageBuilder; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.*; /** * Unit tests for {@link UserDestinationMessageHandler}. @@ -39,54 +40,60 @@ public class UserDestinationMessageHandlerTests { private UserDestinationMessageHandler messageHandler; - private MessageSendingOperations messagingTemplate; + + @Mock + private SubscribableChannel brokerChannel; private UserSessionRegistry registry; @Before public void setup() { - this.messagingTemplate = Mockito.mock(MessageSendingOperations.class); + MockitoAnnotations.initMocks(this); this.registry = new DefaultUserSessionRegistry(); DefaultUserDestinationResolver resolver = new DefaultUserDestinationResolver(this.registry); - this.messageHandler = new UserDestinationMessageHandler(this.messagingTemplate, resolver); + this.messageHandler = new UserDestinationMessageHandler(new StubMessageChannel(), + new StubMessageChannel(), this.brokerChannel, resolver); } @Test public void handleSubscribe() { this.registry.registerSessionId("joe", "123"); + when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true); this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "/user/queue/foo")); - ArgumentCaptor captor1 = ArgumentCaptor.forClass(String.class); - ArgumentCaptor captor2 = ArgumentCaptor.forClass(Message.class); - Mockito.verify(this.messagingTemplate).send(captor1.capture(), captor2.capture()); + ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); + Mockito.verify(this.brokerChannel).send(captor.capture()); - assertEquals("/queue/foo-user123", captor1.getValue()); + assertEquals("/queue/foo-user123", + captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER)); } @Test public void handleUnsubscribe() { this.registry.registerSessionId("joe", "123"); + when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true); this.messageHandler.handleMessage(createMessage(SimpMessageType.UNSUBSCRIBE, "joe", "/user/queue/foo")); - ArgumentCaptor captor1 = ArgumentCaptor.forClass(String.class); - ArgumentCaptor captor2 = ArgumentCaptor.forClass(Message.class); - Mockito.verify(this.messagingTemplate).send(captor1.capture(), captor2.capture()); + ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); + Mockito.verify(this.brokerChannel).send(captor.capture()); - assertEquals("/queue/foo-user123", captor1.getValue()); + assertEquals("/queue/foo-user123", + captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER)); } @Test public void handleMessage() { this.registry.registerSessionId("joe", "123"); + when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true); this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "/user/joe/queue/foo")); - ArgumentCaptor captor1 = ArgumentCaptor.forClass(String.class); - ArgumentCaptor captor2 = ArgumentCaptor.forClass(Message.class); - Mockito.verify(this.messagingTemplate).send(captor1.capture(), captor2.capture()); + ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); + Mockito.verify(this.brokerChannel).send(captor.capture()); - assertEquals("/queue/foo-user123", captor1.getValue()); + assertEquals("/queue/foo-user123", + captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER)); } @@ -95,23 +102,23 @@ public class UserDestinationMessageHandlerTests { // no destination this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", null)); - Mockito.verifyZeroInteractions(this.messagingTemplate); + Mockito.verifyZeroInteractions(this.brokerChannel); // not a user destination this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "/queue/foo")); - Mockito.verifyZeroInteractions(this.messagingTemplate); + Mockito.verifyZeroInteractions(this.brokerChannel); // subscribe + no user this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, null, "/user/queue/foo")); - Mockito.verifyZeroInteractions(this.messagingTemplate); + Mockito.verifyZeroInteractions(this.brokerChannel); // subscribe + not a user destination this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "/queue/foo")); - Mockito.verifyZeroInteractions(this.messagingTemplate); + Mockito.verifyZeroInteractions(this.brokerChannel); // no match on message type this.messageHandler.handleMessage(createMessage(SimpMessageType.CONNECT, "joe", "user/joe/queue/foo")); - Mockito.verifyZeroInteractions(this.messagingTemplate); + Mockito.verifyZeroInteractions(this.brokerChannel); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index 0e4688f5239..d0fce950c2c 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -32,10 +32,7 @@ import org.junit.Before; import org.junit.Test; import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageDeliveryException; -import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.MessagingException; +import org.springframework.messaging.*; import org.springframework.messaging.simp.BrokerAvailabilityEvent; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; @@ -92,7 +89,8 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } private void createAndStartRelay() throws InterruptedException { - this.relay = new StompBrokerRelayMessageHandler(this.responseChannel, Arrays.asList("/queue/", "/topic/")); + this.relay = new StompBrokerRelayMessageHandler(new StubMessageChannel(), + this.responseChannel, new StubMessageChannel(), Arrays.asList("/queue/", "/topic/")); this.relay.setRelayPort(this.port); this.relay.setApplicationEventPublisher(this.eventPublisher); this.relay.setSystemHeartbeatReceiveInterval(0); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java index 0eeff202f02..824d978e11f 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java @@ -53,7 +53,8 @@ public class StompBrokerRelayMessageHandlerTests { this.tcpClient = new StubTcpOperations(); - this.brokerRelay = new StompBrokerRelayMessageHandler(new StubMessageChannel(), Arrays.asList("/topic")); + this.brokerRelay = new StompBrokerRelayMessageHandler(new StubMessageChannel(), + new StubMessageChannel(), new StubMessageChannel(), Arrays.asList("/topic")); this.brokerRelay.setTcpClient(tcpClient); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index eb11a036251..d23005c6ff9 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -21,10 +21,8 @@ import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.MessagingException; +import org.springframework.context.SmartLifecycle; +import org.springframework.messaging.*; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -50,11 +48,14 @@ import org.springframework.web.socket.support.SubProtocolCapable; * * @since 4.0 */ -public class SubProtocolWebSocketHandler implements SubProtocolCapable, WebSocketHandler, MessageHandler { +public class SubProtocolWebSocketHandler + implements SubProtocolCapable, WebSocketHandler, MessageHandler, SmartLifecycle { private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class); - private final MessageChannel clientOutboundChannel; + private final MessageChannel clientInboundChannel; + + private final SubscribableChannel clientOutboundChannel; private final Map protocolHandlers = new TreeMap(String.CASE_INSENSITIVE_ORDER); @@ -63,9 +64,15 @@ public class SubProtocolWebSocketHandler implements SubProtocolCapable, WebSocke private final Map sessions = new ConcurrentHashMap(); + private Object lifecycleMonitor = new Object(); + + private volatile boolean running = false; - public SubProtocolWebSocketHandler(MessageChannel clientOutboundChannel) { + + public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) { + Assert.notNull(clientInboundChannel, "ClientInboundChannel must not be null"); Assert.notNull(clientOutboundChannel, "ClientOutboundChannel must not be null"); + this.clientInboundChannel = clientInboundChannel; this.clientOutboundChannel = clientOutboundChannel; } @@ -82,6 +89,11 @@ public class SubProtocolWebSocketHandler implements SubProtocolCapable, WebSocke } } + public List getProtocolHandlers() { + return new ArrayList(protocolHandlers.values()); + } + + /** * Register a sub-protocol handler. */ @@ -101,9 +113,9 @@ public class SubProtocolWebSocketHandler implements SubProtocolCapable, WebSocke } /** - * @return the configured sub-protocol handlers + * Return the sub-protocols keyed by protocol name. */ - public Map getProtocolHandlers() { + public Map getProtocolHandlerMap() { return this.protocolHandlers; } @@ -133,10 +145,51 @@ public class SubProtocolWebSocketHandler implements SubProtocolCapable, WebSocke return new ArrayList(this.protocolHandlers.keySet()); } + @Override + public boolean isAutoStartup() { + return true; + } + + @Override + public int getPhase() { + return Integer.MAX_VALUE; + } + + @Override + public final boolean isRunning() { + synchronized (this.lifecycleMonitor) { + return this.running; + } + } + + @Override + public final void start() { + synchronized (this.lifecycleMonitor) { + this.clientOutboundChannel.subscribe(this); + this.running = true; + } + } + + @Override + public final void stop() { + synchronized (this.lifecycleMonitor) { + this.running = false; + this.clientOutboundChannel.unsubscribe(this); + } + } + + @Override + public final void stop(Runnable callback) { + synchronized (this.lifecycleMonitor) { + stop(); + callback.run(); + } + } + @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { this.sessions.put(session.getId(), session); - findProtocolHandler(session).afterSessionStarted(session, this.clientOutboundChannel); + findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel); } protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) { @@ -167,7 +220,7 @@ public class SubProtocolWebSocketHandler implements SubProtocolCapable, WebSocke @Override public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { - findProtocolHandler(session).handleMessageFromClient(session, message, this.clientOutboundChannel); + findProtocolHandler(session).handleMessageFromClient(session, message, this.clientInboundChannel); } @Override @@ -216,7 +269,7 @@ public class SubProtocolWebSocketHandler implements SubProtocolCapable, WebSocke @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { this.sessions.remove(session.getId()); - findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientOutboundChannel); + findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistry.java index aec9ac94d6d..2a01f88bffe 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistry.java @@ -55,8 +55,10 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry { public WebMvcStompEndpointRegistry(WebSocketHandler webSocketHandler, UserSessionRegistry userSessionRegistry, TaskScheduler defaultSockJsTaskScheduler) { + Assert.notNull(webSocketHandler); Assert.notNull(userSessionRegistry); + this.webSocketHandler = webSocketHandler; this.subProtocolWebSocketHandler = unwrapSubProtocolWebSocketHandler(webSocketHandler); this.stompHandler = new StompSubProtocolHandler(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebSocketMessageBrokerConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebSocketMessageBrokerConfigurationSupport.java index 22e55d63ad1..d523e8344d6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebSocketMessageBrokerConfigurationSupport.java @@ -53,9 +53,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac @Bean public WebSocketHandler subProtocolWebSocketHandler() { - SubProtocolWebSocketHandler handler = new SubProtocolWebSocketHandler(clientInboundChannel()); - clientOutboundChannel().subscribe(handler); - return handler; + return new SubProtocolWebSocketHandler(clientInboundChannel(), clientOutboundChannel()); } /** diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/xml/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/xml/MessageBrokerBeanDefinitionParser.java new file mode 100644 index 00000000000..02f129b2caf --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/xml/MessageBrokerBeanDefinitionParser.java @@ -0,0 +1,436 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.messaging.config.xml; + +import org.springframework.beans.MutablePropertyValues; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConstructorArgumentValues; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.parsing.BeanComponentDefinition; +import org.springframework.beans.factory.parsing.CompositeComponentDefinition; +import org.springframework.beans.factory.support.ManagedList; +import org.springframework.beans.factory.support.ManagedMap; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.factory.xml.BeanDefinitionParser; +import org.springframework.beans.factory.xml.ParserContext; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.simp.handler.DefaultUserDestinationResolver; +import org.springframework.messaging.simp.handler.DefaultUserSessionRegistry; +import org.springframework.messaging.simp.handler.SimpAnnotationMethodMessageHandler; +import org.springframework.messaging.simp.handler.SimpleBrokerMessageHandler; +import org.springframework.messaging.simp.handler.UserDestinationMessageHandler; +import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; +import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; +import org.springframework.messaging.support.converter.ByteArrayMessageConverter; +import org.springframework.messaging.support.converter.CompositeMessageConverter; +import org.springframework.messaging.support.converter.DefaultContentTypeResolver; +import org.springframework.messaging.support.converter.MappingJackson2MessageConverter; +import org.springframework.messaging.support.converter.StringMessageConverter; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.util.ClassUtils; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.StringUtils; +import org.springframework.util.xml.DomUtils; +import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; +import org.springframework.web.socket.messaging.StompSubProtocolHandler; +import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; +import org.springframework.web.socket.server.config.xml.WebSocketNamespaceUtils; +import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; +import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler; +import org.w3c.dom.Element; + +import java.util.Arrays; +import java.util.List; + + +/** + * A {@link org.springframework.beans.factory.xml.BeanDefinitionParser} + * that provides the configuration for the + * {@code } XML namespace element. + *

+ * Registers a Spring MVC {@link org.springframework.web.servlet.handler.SimpleUrlHandlerMapping} + * with order=1 to map HTTP WebSocket handshake requests from STOMP/WebSocket clients. + *

+ * Registers the following {@link org.springframework.messaging.MessageChannel}s: + *

    + *
  • "clientInboundChannel" for receiving messages from clients (e.g. WebSocket clients) + *
  • "clientOutboundChannel" for sending messages to clients (e.g. WebSocket clients) + *
  • "brokerChannel" for sending messages from within the application to the message broker + *
+ *

+ * Registers one of the following based on the selected message broker options: + *

    + *
  • a {@link SimpleBrokerMessageHandler} if the is used + *
  • a {@link StompBrokerRelayMessageHandler} if the is used + *
+ *

+ * Registers a {@link UserDestinationMessageHandler} for handling user destinations. + * + * @author Brian Clozel + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { + + protected static final String SOCKJS_SCHEDULER_BEAN_NAME = "messageBrokerSockJsScheduler"; + + private static final int DEFAULT_MAPPING_ORDER = 1; + + private static final boolean jackson2Present= ClassUtils.isPresent( + "com.fasterxml.jackson.databind.ObjectMapper", MessageBrokerBeanDefinitionParser.class.getClassLoader()); + + + @Override + public BeanDefinition parse(Element element, ParserContext parserCxt) { + + Object source = parserCxt.extractSource(element); + CompositeComponentDefinition compDefinition = new CompositeComponentDefinition(element.getTagName(), source); + parserCxt.pushContainingComponent(compDefinition); + + String orderAttribute = element.getAttribute("order"); + int order = orderAttribute.isEmpty() ? DEFAULT_MAPPING_ORDER : Integer.valueOf(orderAttribute); + + ManagedMap urlMap = new ManagedMap(); + urlMap.setSource(source); + + RootBeanDefinition handlerMappingDef = new RootBeanDefinition(SimpleUrlHandlerMapping.class); + handlerMappingDef.getPropertyValues().add("order", order); + handlerMappingDef.getPropertyValues().add("urlMap", urlMap); + + String channelName = "clientInboundChannel"; + Element channelElem = DomUtils.getChildElementByTagName(element, "client-inbound-channel"); + RuntimeBeanReference clientInChannel = getMessageChannel(channelName, channelElem, parserCxt, source); + + channelName = "clientOutboundChannel"; + channelElem = DomUtils.getChildElementByTagName(element, "client-outbound-channel"); + RuntimeBeanReference clientOutChannel = getMessageChannel(channelName, channelElem, parserCxt, source); + + RootBeanDefinition userSessionRegistryDef = new RootBeanDefinition(DefaultUserSessionRegistry.class); + String userSessionRegistryName = registerBeanDef(userSessionRegistryDef, parserCxt, source); + RuntimeBeanReference userSessionRegistry = new RuntimeBeanReference(userSessionRegistryName); + + RuntimeBeanReference subProtocolWebSocketHandler = registerSubProtocolWebSocketHandler( + clientInChannel, clientOutChannel, userSessionRegistry, parserCxt, source); + + List stompEndpointElements = DomUtils.getChildElementsByTagName(element, "stomp-endpoint"); + for(Element stompEndpointElement : stompEndpointElements) { + + RuntimeBeanReference requestHandler = registerHttpRequestHandler( + stompEndpointElement, subProtocolWebSocketHandler, parserCxt, source); + + List paths = Arrays.asList(stompEndpointElement.getAttribute("path").split(",")); + for(String path : paths) { + if (DomUtils.getChildElementByTagName(stompEndpointElement, "sockjs") != null) { + path = path.endsWith("/") ? path + "**" : path + "/**"; + } + urlMap.put(path, requestHandler); + } + } + + registerBeanDef(handlerMappingDef, parserCxt, source); + + channelName = "brokerChannel"; + channelElem = DomUtils.getChildElementByTagName(element, "broker-channel"); + RuntimeBeanReference brokerChannel = getMessageChannel(channelName, channelElem, parserCxt, source); + registerMessageBroker(element, clientInChannel, clientOutChannel, brokerChannel, parserCxt, source); + + RuntimeBeanReference messageConverter = registerBrokerMessageConverter(parserCxt, source); + RuntimeBeanReference messagingTemplate = registerBrokerMessagingTemplate(element, brokerChannel, + messageConverter, parserCxt, source); + + registerAnnotationMethodMessageHandler(element, clientInChannel, clientOutChannel, + messageConverter, messagingTemplate, parserCxt, source); + + RuntimeBeanReference userDestinationResolver = registerUserDestinationResolver(element, + userSessionRegistryDef, parserCxt, source); + + registerUserDestinationMessageHandler(clientInChannel, clientOutChannel, brokerChannel, + userDestinationResolver, parserCxt, source); + + parserCxt.popAndRegisterContainingComponent(); + + return null; + } + + private RuntimeBeanReference getMessageChannel(String channelName, Element channelElement, + ParserContext parserCxt, Object source) { + + RootBeanDefinition executorDef = null; + + if (channelElement != null) { + Element executor = DomUtils.getChildElementByTagName(channelElement, "executor"); + if (executor != null) { + executorDef = new RootBeanDefinition(ThreadPoolTaskExecutor.class); + String attrValue = executor.getAttribute("core-pool-size"); + if (!StringUtils.isEmpty(attrValue)) { + executorDef.getPropertyValues().add("corePoolSize", attrValue); + } + attrValue = executor.getAttribute("max-pool-size"); + if (!StringUtils.isEmpty(attrValue)) { + executorDef.getPropertyValues().add("maxPoolSize", attrValue); + } + attrValue = executor.getAttribute("keep-alive-seconds"); + if (!StringUtils.isEmpty(attrValue)) { + executorDef.getPropertyValues().add("keepAliveSeconds", attrValue); + } + attrValue = executor.getAttribute("queue-capacity"); + if (!StringUtils.isEmpty(attrValue)) { + executorDef.getPropertyValues().add("queueCapacity", attrValue); + } + } + } + else if (!channelName.equals("brokerChannel")) { + executorDef = new RootBeanDefinition(ThreadPoolTaskExecutor.class); + } + + ConstructorArgumentValues values = new ConstructorArgumentValues(); + if (executorDef != null) { + executorDef.getPropertyValues().add("threadNamePrefix", channelName + "-"); + String executorName = channelName + "Executor"; + registerBeanDefByName(executorName, executorDef, parserCxt, source); + values.addIndexedArgumentValue(0, new RuntimeBeanReference(executorName)); + } + + RootBeanDefinition channelDef = new RootBeanDefinition(ExecutorSubscribableChannel.class, values, null); + + if (channelElement != null) { + Element interceptorsElement = DomUtils.getChildElementByTagName(channelElement, "interceptors"); + ManagedList interceptorList = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, parserCxt); + channelDef.getPropertyValues().add("interceptors", interceptorList); + } + + registerBeanDefByName(channelName, channelDef, parserCxt, source); + return new RuntimeBeanReference(channelName); + } + + private RuntimeBeanReference registerSubProtocolWebSocketHandler( + RuntimeBeanReference clientInChannel, RuntimeBeanReference clientOutChannel, + RuntimeBeanReference userSessionRegistry, ParserContext parserCxt, Object source) { + + RootBeanDefinition stompHandlerDef = new RootBeanDefinition(StompSubProtocolHandler.class); + stompHandlerDef.getPropertyValues().add("userSessionRegistry", userSessionRegistry); + registerBeanDef(stompHandlerDef, parserCxt, source); + + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, clientInChannel); + cavs.addIndexedArgumentValue(1, clientOutChannel); + + RootBeanDefinition subProtocolWshDef = new RootBeanDefinition(SubProtocolWebSocketHandler.class, cavs, null); + subProtocolWshDef.getPropertyValues().addPropertyValue("protocolHandlers", stompHandlerDef); + String subProtocolWshName = registerBeanDef(subProtocolWshDef, parserCxt, source); + return new RuntimeBeanReference(subProtocolWshName); + } + + private RuntimeBeanReference registerHttpRequestHandler(Element stompEndpointElement, + RuntimeBeanReference subProtocolWebSocketHandler, ParserContext parserCxt, Object source) { + + RootBeanDefinition httpRequestHandlerDef; + + RuntimeBeanReference handshakeHandler = + WebSocketNamespaceUtils.registerHandshakeHandler(stompEndpointElement, parserCxt, source); + + RuntimeBeanReference sockJsService = WebSocketNamespaceUtils.registerSockJsService( + stompEndpointElement, SOCKJS_SCHEDULER_BEAN_NAME, parserCxt, source); + + if (sockJsService != null) { + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, sockJsService); + cavs.addIndexedArgumentValue(1, subProtocolWebSocketHandler); + httpRequestHandlerDef = new RootBeanDefinition(SockJsHttpRequestHandler.class, cavs, null); + } + else { + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, subProtocolWebSocketHandler); + if(handshakeHandler != null) { + cavs.addIndexedArgumentValue(1, handshakeHandler); + } + httpRequestHandlerDef = new RootBeanDefinition(WebSocketHttpRequestHandler.class, cavs, null); + // TODO: httpRequestHandlerDef.getPropertyValues().add("handshakeInterceptors", ...); + } + + String httpRequestHandlerBeanName = registerBeanDef(httpRequestHandlerDef, parserCxt, source); + return new RuntimeBeanReference(httpRequestHandlerBeanName); + } + + private void registerMessageBroker(Element messageBrokerElement, RuntimeBeanReference clientInChannelDef, + RuntimeBeanReference clientOutChannelDef, RuntimeBeanReference brokerChannelDef, + ParserContext parserCxt, Object source) { + + Element simpleBrokerElem = DomUtils.getChildElementByTagName(messageBrokerElement, "simple-broker"); + Element brokerRelayElem = DomUtils.getChildElementByTagName(messageBrokerElement, "stomp-broker-relay"); + + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, clientInChannelDef); + cavs.addIndexedArgumentValue(1, clientOutChannelDef); + cavs.addIndexedArgumentValue(2, brokerChannelDef); + + if (simpleBrokerElem != null) { + + String prefix = simpleBrokerElem.getAttribute("prefix"); + cavs.addIndexedArgumentValue(3, Arrays.asList(prefix.split(","))); + RootBeanDefinition brokerDef = new RootBeanDefinition(SimpleBrokerMessageHandler.class, cavs, null); + registerBeanDef(brokerDef, parserCxt, source); + } + else if (brokerRelayElem != null) { + + String prefix = brokerRelayElem.getAttribute("prefix"); + cavs.addIndexedArgumentValue(3, Arrays.asList(prefix.split(","))); + + MutablePropertyValues mpvs = new MutablePropertyValues(); + String relayHost = brokerRelayElem.getAttribute("relay-host"); + if(!relayHost.isEmpty()) { + mpvs.add("relayHost",relayHost); + } + String relayPort = brokerRelayElem.getAttribute("relay-port"); + if(!relayPort.isEmpty()) { + mpvs.add("relayPort", Integer.valueOf(relayPort)); + } + String systemLogin = brokerRelayElem.getAttribute("system-login"); + if(!systemLogin.isEmpty()) { + mpvs.add("systemLogin",systemLogin); + } + String systemPasscode = brokerRelayElem.getAttribute("system-passcode"); + if(!systemPasscode.isEmpty()) { + mpvs.add("systemPasscode",systemPasscode); + } + String systemHeartbeatSendInterval = brokerRelayElem.getAttribute("system-heartbeat-send-interval"); + if(!systemHeartbeatSendInterval.isEmpty()) { + mpvs.add("systemHeartbeatSendInterval",Long.parseLong(systemHeartbeatSendInterval)); + } + String systemHeartbeatReceiveInterval = brokerRelayElem.getAttribute("system-heartbeat-receive-interval"); + if(!systemHeartbeatReceiveInterval.isEmpty()) { + mpvs.add("systemHeartbeatReceiveInterval",Long.parseLong(systemHeartbeatReceiveInterval)); + } + String virtualHost = brokerRelayElem.getAttribute("virtual-host"); + if(!virtualHost.isEmpty()) { + mpvs.add("virtualHost",virtualHost); + } + + RootBeanDefinition messageBrokerDef = new RootBeanDefinition(StompBrokerRelayMessageHandler.class, cavs, mpvs); + registerBeanDef(messageBrokerDef, parserCxt, source); + } + + } + + private RuntimeBeanReference registerBrokerMessageConverter(ParserContext parserCxt, Object source) { + + RootBeanDefinition contentTypeResolverDef = new RootBeanDefinition(DefaultContentTypeResolver.class); + + ManagedList convertersDef = new ManagedList(); + if (jackson2Present) { + convertersDef.add(new RootBeanDefinition(MappingJackson2MessageConverter.class)); + contentTypeResolverDef.getPropertyValues().add("defaultMimeType", MimeTypeUtils.APPLICATION_JSON); + } + convertersDef.add(new RootBeanDefinition(StringMessageConverter.class)); + convertersDef.add(new RootBeanDefinition(ByteArrayMessageConverter.class)); + + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, convertersDef); + cavs.addIndexedArgumentValue(1, contentTypeResolverDef); + + RootBeanDefinition brokerMessage = new RootBeanDefinition(CompositeMessageConverter.class, cavs, null); + return new RuntimeBeanReference(registerBeanDef(brokerMessage,parserCxt, source)); + } + + private RuntimeBeanReference registerBrokerMessagingTemplate( + Element element, RuntimeBeanReference brokerChannelDef, RuntimeBeanReference messageConverterRef, + ParserContext parserCxt, Object source) { + + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, brokerChannelDef); + RootBeanDefinition messagingTemplateDef = new RootBeanDefinition(SimpMessagingTemplate.class,cavs, null); + + String userDestinationPrefixAttribute = element.getAttribute("user-destination-prefix"); + if(!userDestinationPrefixAttribute.isEmpty()) { + messagingTemplateDef.getPropertyValues().add("userDestinationPrefix", userDestinationPrefixAttribute); + } + messagingTemplateDef.getPropertyValues().add("messageConverter", messageConverterRef); + + return new RuntimeBeanReference(registerBeanDef(messagingTemplateDef,parserCxt, source)); + } + + private void registerAnnotationMethodMessageHandler(Element messageBrokerElement, + RuntimeBeanReference clientInChannelDef, RuntimeBeanReference clientOutChannelDef, + RuntimeBeanReference brokerMessageConverterRef, RuntimeBeanReference brokerMessagingTemplateRef, + ParserContext parserCxt, Object source) { + + String applicationDestinationPrefix = messageBrokerElement.getAttribute("application-destination-prefix"); + + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, clientInChannelDef); + cavs.addIndexedArgumentValue(1, clientOutChannelDef); + cavs.addIndexedArgumentValue(2, brokerMessagingTemplateRef); + + MutablePropertyValues mpvs = new MutablePropertyValues(); + mpvs.add("destinationPrefixes",Arrays.asList(applicationDestinationPrefix.split(","))); + mpvs.add("messageConverter", brokerMessageConverterRef); + + RootBeanDefinition annotationMethodMessageHandlerDef = + new RootBeanDefinition(SimpAnnotationMethodMessageHandler.class, cavs, mpvs); + + registerBeanDef(annotationMethodMessageHandlerDef, parserCxt, source); + } + + private RuntimeBeanReference registerUserDestinationResolver(Element messageBrokerElement, + BeanDefinition userSessionRegistryDef, ParserContext parserCxt, Object source) { + + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, userSessionRegistryDef); + RootBeanDefinition userDestinationResolverDef = + new RootBeanDefinition(DefaultUserDestinationResolver.class, cavs, null); + String prefix = messageBrokerElement.getAttribute("user-destination-prefix"); + if (!prefix.isEmpty()) { + userDestinationResolverDef.getPropertyValues().add("userDestinationPrefix", prefix); + } + String userDestinationResolverName = registerBeanDef(userDestinationResolverDef, parserCxt, source); + return new RuntimeBeanReference(userDestinationResolverName); + } + + private RuntimeBeanReference registerUserDestinationMessageHandler(RuntimeBeanReference clientInChannelDef, + RuntimeBeanReference clientOutChannelDef, RuntimeBeanReference brokerChannelDef, + RuntimeBeanReference userDestinationResolverRef, ParserContext parserCxt, Object source) { + + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, clientInChannelDef); + cavs.addIndexedArgumentValue(1, clientOutChannelDef); + cavs.addIndexedArgumentValue(2, brokerChannelDef); + cavs.addIndexedArgumentValue(3, userDestinationResolverRef); + + RootBeanDefinition userDestinationMessageHandlerDef = + new RootBeanDefinition(UserDestinationMessageHandler.class, cavs, null); + + String userDestinationMessageHandleName = registerBeanDef(userDestinationMessageHandlerDef, parserCxt, source); + return new RuntimeBeanReference(userDestinationMessageHandleName); + } + + + private static String registerBeanDef(RootBeanDefinition beanDef, ParserContext parserCxt, Object source) { + String beanName = parserCxt.getReaderContext().generateBeanName(beanDef); + registerBeanDefByName(beanName, beanDef, parserCxt, source); + return beanName; + } + + private static void registerBeanDefByName(String beanName, RootBeanDefinition beanDef, + ParserContext parserCxt, Object source) { + + beanDef.setSource(source); + beanDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + parserCxt.getRegistry().registerBeanDefinition(beanName, beanDef); + parserCxt.registerComponent(new BeanComponentDefinition(beanDef, beanName)); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/xml/WebSocketNamespaceHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/xml/WebSocketNamespaceHandler.java new file mode 100644 index 00000000000..12c1d1d3afe --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/xml/WebSocketNamespaceHandler.java @@ -0,0 +1,44 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.messaging.config.xml; + +import org.springframework.beans.factory.xml.NamespaceHandlerSupport; +import org.springframework.util.ClassUtils; +import org.springframework.web.socket.server.config.xml.HandlersBeanDefinitionParser; + + +/** + * {@link org.springframework.beans.factory.xml.NamespaceHandler} for Spring WebSocket + * configuration namespace. + * + * @author Brian Clozel + * @since 4.0 + */ +public class WebSocketNamespaceHandler extends NamespaceHandlerSupport { + + private static boolean isSpringMessagingPresent = ClassUtils.isPresent( + "org.springframework.messaging.Message", WebSocketNamespaceHandler.class.getClassLoader()); + + + @Override + public void init() { + registerBeanDefinitionParser("handlers", new HandlersBeanDefinitionParser()); + if (isSpringMessagingPresent) { + registerBeanDefinitionParser("message-broker", new MessageBrokerBeanDefinitionParser()); + } + } +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/xml/package-info.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/xml/package-info.java new file mode 100644 index 00000000000..4e0e8f7dbc8 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/xml/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Support for the {@code } XML namespace element. + */ +package org.springframework.web.socket.messaging.config.xml; \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/xml/HandlersBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/xml/HandlersBeanDefinitionParser.java new file mode 100644 index 00000000000..2945da48588 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/xml/HandlersBeanDefinitionParser.java @@ -0,0 +1,194 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.config.xml; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConstructorArgumentValues; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.parsing.BeanComponentDefinition; +import org.springframework.beans.factory.parsing.CompositeComponentDefinition; +import org.springframework.beans.factory.support.ManagedList; +import org.springframework.beans.factory.support.ManagedMap; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.factory.xml.BeanDefinitionParser; +import org.springframework.beans.factory.xml.ParserContext; +import org.springframework.util.xml.DomUtils; +import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; +import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; +import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler; +import org.w3c.dom.Element; + +import java.util.Arrays; +import java.util.List; + + +/** + * A {@link BeanDefinitionParser} that provides the configuration for the + * {@code } namespace element. It registers a Spring MVC + * {@link org.springframework.web.servlet.handler.SimpleUrlHandlerMapping} + * to map HTTP WebSocket handshake requests to + * {@link org.springframework.web.socket.WebSocketHandler}s. + * + * @author Brian Clozel + * @since 4.0 + */ +public class HandlersBeanDefinitionParser implements BeanDefinitionParser { + + private static final String SOCK_JS_SCHEDULER_NAME = "SockJsScheduler"; + + private static final int DEFAULT_MAPPING_ORDER = 1; + + + @Override + public BeanDefinition parse(Element element, ParserContext parserCxt) { + + Object source = parserCxt.extractSource(element); + CompositeComponentDefinition compDefinition = new CompositeComponentDefinition(element.getTagName(), source); + parserCxt.pushContainingComponent(compDefinition); + + String orderAttribute = element.getAttribute("order"); + int order = orderAttribute.isEmpty() ? DEFAULT_MAPPING_ORDER : Integer.valueOf(orderAttribute); + + RootBeanDefinition handlerMappingDef = new RootBeanDefinition(SimpleUrlHandlerMapping.class); + handlerMappingDef.setSource(source); + handlerMappingDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + handlerMappingDef.getPropertyValues().add("order", order); + String handlerMappingName = parserCxt.getReaderContext().registerWithGeneratedName(handlerMappingDef); + + RuntimeBeanReference handshakeHandler = WebSocketNamespaceUtils.registerHandshakeHandler(element, parserCxt, source); + Element interceptorsElement = DomUtils.getChildElementByTagName(element, "handshake-interceptors"); + ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, parserCxt); + RuntimeBeanReference sockJsServiceRef = + WebSocketNamespaceUtils.registerSockJsService(element, SOCK_JS_SCHEDULER_NAME, parserCxt, source); + + HandlerMappingStrategy strategy = createHandlerMappingStrategy(sockJsServiceRef, handshakeHandler, interceptors); + + List mappingElements = DomUtils.getChildElementsByTagName(element, "mapping"); + ManagedMap urlMap = new ManagedMap(); + urlMap.setSource(source); + + for(Element mappingElement : mappingElements) { + urlMap.putAll(strategy.createMappings(mappingElement, parserCxt)); + } + handlerMappingDef.getPropertyValues().add("urlMap", urlMap); + + parserCxt.registerComponent(new BeanComponentDefinition(handlerMappingDef, handlerMappingName)); + parserCxt.popAndRegisterContainingComponent(); + return null; + } + + + private interface HandlerMappingStrategy { + + public ManagedMap createMappings(Element mappingElement, ParserContext parserContext); + } + + private HandlerMappingStrategy createHandlerMappingStrategy( + RuntimeBeanReference sockJsServiceRef, RuntimeBeanReference handshakeHandlerRef, + ManagedList interceptorsList) { + + if(sockJsServiceRef != null) { + SockJSHandlerMappingStrategy strategy = new SockJSHandlerMappingStrategy(); + strategy.setSockJsServiceRef(sockJsServiceRef); + return strategy; + } + else { + WebSocketHandlerMappingStrategy strategy = new WebSocketHandlerMappingStrategy(); + strategy.setHandshakeHandlerReference(handshakeHandlerRef); + strategy.setInterceptorsList(interceptorsList); + return strategy; + } + } + + private class WebSocketHandlerMappingStrategy implements HandlerMappingStrategy { + + private RuntimeBeanReference handshakeHandlerReference; + + private ManagedList interceptorsList; + + public void setHandshakeHandlerReference(RuntimeBeanReference handshakeHandlerReference) { + this.handshakeHandlerReference = handshakeHandlerReference; + } + + public void setInterceptorsList(ManagedList interceptorsList) { this.interceptorsList = interceptorsList; } + + @Override + public ManagedMap createMappings(Element mappingElement, ParserContext parserContext) { + + ManagedMap urlMap = new ManagedMap(); + Object source = parserContext.extractSource(mappingElement); + + List mappings = Arrays.asList(mappingElement.getAttribute("path").split(",")); + RuntimeBeanReference webSocketHandlerReference = new RuntimeBeanReference(mappingElement.getAttribute("handler")); + + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, webSocketHandlerReference); + if(this.handshakeHandlerReference != null) { + cavs.addIndexedArgumentValue(1, this.handshakeHandlerReference); + } + RootBeanDefinition requestHandlerDef = new RootBeanDefinition(WebSocketHttpRequestHandler.class, cavs, null); + requestHandlerDef.setSource(source); + requestHandlerDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + requestHandlerDef.getPropertyValues().add("handshakeInterceptors", this.interceptorsList); + String requestHandlerName = parserContext.getReaderContext().registerWithGeneratedName(requestHandlerDef); + RuntimeBeanReference requestHandlerRef = new RuntimeBeanReference(requestHandlerName); + + for(String mapping : mappings) { + urlMap.put(mapping, requestHandlerRef); + } + + return urlMap; + } + } + + private class SockJSHandlerMappingStrategy implements HandlerMappingStrategy { + + private RuntimeBeanReference sockJsServiceRef; + + public void setSockJsServiceRef(RuntimeBeanReference sockJsServiceRef) { + this.sockJsServiceRef = sockJsServiceRef; + } + + @Override + public ManagedMap createMappings(Element mappingElement, ParserContext parserContext) { + + ManagedMap urlMap = new ManagedMap(); + Object source = parserContext.extractSource(mappingElement); + + List mappings = Arrays.asList(mappingElement.getAttribute("path").split(",")); + RuntimeBeanReference webSocketHandlerReference = new RuntimeBeanReference(mappingElement.getAttribute("handler")); + + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + cavs.addIndexedArgumentValue(0, this.sockJsServiceRef, "SockJsService"); + cavs.addIndexedArgumentValue(1, webSocketHandlerReference, "WebSocketHandler"); + + RootBeanDefinition requestHandlerDef = new RootBeanDefinition(SockJsHttpRequestHandler.class, cavs, null); + requestHandlerDef.setSource(source); + requestHandlerDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + String requestHandlerName = parserContext.getReaderContext().registerWithGeneratedName(requestHandlerDef); + RuntimeBeanReference requestHandlerRef = new RuntimeBeanReference(requestHandlerName); + + for(String path : mappings) { + String pathPattern = path.endsWith("/") ? path + "**" : path + "/**"; + urlMap.put(pathPattern, requestHandlerRef); + } + + return urlMap; + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/xml/WebSocketNamespaceUtils.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/xml/WebSocketNamespaceUtils.java new file mode 100644 index 00000000000..038bb7a5773 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/xml/WebSocketNamespaceUtils.java @@ -0,0 +1,178 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.config.xml; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConstructorArgumentValues; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.parsing.BeanComponentDefinition; +import org.springframework.beans.factory.support.ManagedList; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.factory.xml.ParserContext; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; +import org.springframework.util.xml.DomUtils; +import org.springframework.web.socket.server.DefaultHandshakeHandler; +import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; +import org.w3c.dom.Element; + +import java.util.Collections; + +/** + * Provides utility methods for parsing common WebSocket XML namespace elements. + * + * @author Brian Clozel + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class WebSocketNamespaceUtils { + + + public static RuntimeBeanReference registerHandshakeHandler(Element element, ParserContext parserContext, Object source) { + + RuntimeBeanReference handlerRef; + + Element handlerElem = DomUtils.getChildElementByTagName(element, "handshake-handler"); + if(handlerElem != null) { + handlerRef = new RuntimeBeanReference(handlerElem.getAttribute("ref")); + } + else { + RootBeanDefinition defaultHandlerDef = new RootBeanDefinition(DefaultHandshakeHandler.class); + defaultHandlerDef.setSource(source); + defaultHandlerDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + String handlerName = parserContext.getReaderContext().registerWithGeneratedName(defaultHandlerDef); + handlerRef = new RuntimeBeanReference(handlerName); + } + + return handlerRef; + } + + public static RuntimeBeanReference registerSockJsService(Element element, String sockJsSchedulerName, + ParserContext parserContext, Object source) { + + Element sockJsElement = DomUtils.getChildElementByTagName(element, "sockjs"); + + if(sockJsElement != null) { + ConstructorArgumentValues cavs = new ConstructorArgumentValues(); + + // TODO: polish the way constructor arguments are set + + String customTaskSchedulerName = sockJsElement.getAttribute("task-scheduler"); + if(!customTaskSchedulerName.isEmpty()) { + cavs.addIndexedArgumentValue(0, new RuntimeBeanReference(customTaskSchedulerName)); + } + else { + cavs.addIndexedArgumentValue(0, registerSockJsTaskScheduler(sockJsSchedulerName, parserContext, source)); + } + + Element transportHandlersElement = DomUtils.getChildElementByTagName(sockJsElement, "transport-handlers"); + boolean registerDefaults = true; + if(transportHandlersElement != null) { + String registerDefaultsAttribute = transportHandlersElement.getAttribute("register-defaults"); + registerDefaults = !registerDefaultsAttribute.equals("false"); + } + + ManagedList transportHandlersList = parseBeanSubElements(transportHandlersElement, parserContext); + + if(registerDefaults) { + cavs.addIndexedArgumentValue(1, Collections.emptyList()); + if(transportHandlersList.isEmpty()) { + cavs.addIndexedArgumentValue(2, new ConstructorArgumentValues.ValueHolder(null)); + } + else { + cavs.addIndexedArgumentValue(2, transportHandlersList); + } + } + else { + if(transportHandlersList.isEmpty()) { + cavs.addIndexedArgumentValue(1, new ConstructorArgumentValues.ValueHolder(null)); + } + else { + cavs.addIndexedArgumentValue(1, transportHandlersList); + } + cavs.addIndexedArgumentValue(2, new ConstructorArgumentValues.ValueHolder(null)); + } + + RootBeanDefinition sockJsServiceDef = new RootBeanDefinition(DefaultSockJsService.class, cavs, null); + sockJsServiceDef.setSource(source); + + String attrValue = sockJsElement.getAttribute("name"); + if(!attrValue.isEmpty()) { + sockJsServiceDef.getPropertyValues().add("name", attrValue); + } + attrValue = sockJsElement.getAttribute("websocket-enabled"); + if(!attrValue.isEmpty()) { + sockJsServiceDef.getPropertyValues().add("webSocketsEnabled", Boolean.valueOf(attrValue)); + } + attrValue = sockJsElement.getAttribute("session-cookie-needed"); + if(!attrValue.isEmpty()) { + sockJsServiceDef.getPropertyValues().add("sessionCookieNeeded", Boolean.valueOf(attrValue)); + } + attrValue = sockJsElement.getAttribute("stream-bytes-limit"); + if(!attrValue.isEmpty()) { + sockJsServiceDef.getPropertyValues().add("streamBytesLimit", Integer.valueOf(attrValue)); + } + attrValue = sockJsElement.getAttribute("disconnect-delay"); + if(!attrValue.isEmpty()) { + sockJsServiceDef.getPropertyValues().add("disconnectDelay", Long.valueOf(attrValue)); + } + attrValue = sockJsElement.getAttribute("http-message-cache-size"); + if(!attrValue.isEmpty()) { + sockJsServiceDef.getPropertyValues().add("httpMessageCacheSize", Integer.valueOf(attrValue)); + } + attrValue = sockJsElement.getAttribute("heartbeat-time"); + if(!attrValue.isEmpty()) { + sockJsServiceDef.getPropertyValues().add("heartbeatTime", Long.valueOf(attrValue)); + } + + sockJsServiceDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + String sockJsServiceName = parserContext.getReaderContext().registerWithGeneratedName(sockJsServiceDef); + return new RuntimeBeanReference(sockJsServiceName); + } + + return null; + } + + private static RuntimeBeanReference registerSockJsTaskScheduler(String schedulerName, + ParserContext parserContext, Object source) { + + if (!parserContext.getRegistry().containsBeanDefinition(schedulerName)) { + RootBeanDefinition taskSchedulerDef = new RootBeanDefinition(ThreadPoolTaskScheduler.class); + taskSchedulerDef.setSource(source); + taskSchedulerDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + taskSchedulerDef.getPropertyValues().add("threadNamePrefix", schedulerName + "-"); + parserContext.getRegistry().registerBeanDefinition(schedulerName, taskSchedulerDef); + parserContext.registerComponent(new BeanComponentDefinition(taskSchedulerDef, schedulerName)); + } + + return new RuntimeBeanReference(schedulerName); + } + + public static ManagedList parseBeanSubElements(Element parentElement, ParserContext parserContext) { + + ManagedList beans = new ManagedList(); + if (parentElement != null) { + beans.setSource(parserContext.extractSource(parentElement)); + for (Element beanElement : DomUtils.getChildElementsByTagName(parentElement, new String[] { "bean", "ref" })) { + Object object = parserContext.getDelegate().parsePropertySubElement(beanElement, null); + beans.add(object); + } + } + + return beans; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/xml/package-info.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/xml/package-info.java new file mode 100644 index 00000000000..ae07cbf2417 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/xml/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Support for the {@code } XML namespace element. + */ +package org.springframework.web.socket.server.config.xml; \ No newline at end of file diff --git a/spring-websocket/src/main/resources/META-INF/spring.handlers b/spring-websocket/src/main/resources/META-INF/spring.handlers new file mode 100644 index 00000000000..a544492f093 --- /dev/null +++ b/spring-websocket/src/main/resources/META-INF/spring.handlers @@ -0,0 +1,17 @@ +# +# Copyright 2002-2013 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +http\://www.springframework.org/schema/websocket=org.springframework.web.socket.messaging.config.xml.WebSocketNamespaceHandler \ No newline at end of file diff --git a/spring-websocket/src/main/resources/META-INF/spring.schemas b/spring-websocket/src/main/resources/META-INF/spring.schemas new file mode 100644 index 00000000000..f6cd0a347b3 --- /dev/null +++ b/spring-websocket/src/main/resources/META-INF/spring.schemas @@ -0,0 +1,18 @@ +# +# Copyright 2002-2013 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +http\://www.springframework.org/schema/websocket/spring-websocket-4.0.xsd=org/springframework/web/socket/server/config/xml/spring-websocket-4.0.xsd +http\://www.springframework.org/schema/websocket/spring-websocket.xsd=org/springframework/web/socket/server/config/xml/spring-websocket-4.0.xsd diff --git a/spring-websocket/src/main/resources/META-INF/spring.tooling b/spring-websocket/src/main/resources/META-INF/spring.tooling new file mode 100644 index 00000000000..df867548edb --- /dev/null +++ b/spring-websocket/src/main/resources/META-INF/spring.tooling @@ -0,0 +1,4 @@ +# Tooling related information for the mvc namespace +http\://www.springframework.org/schema/websocket@name=websocket Namespace +http\://www.springframework.org/schema/websocket@prefix=websocket +http\://www.springframework.org/schema/websocket@icon=org/springframework/web/socket/server/config/xml/spring-websocket.gif diff --git a/spring-websocket/src/main/resources/org/springframework/web/socket/server/config/xml/spring-websocket-4.0.xsd b/spring-websocket/src/main/resources/org/springframework/web/socket/server/config/xml/spring-websocket-4.0.xsd new file mode 100644 index 00000000000..c493a13e96b --- /dev/null +++ b/spring-websocket/src/main/resources/org/springframework/web/socket/server/config/xml/spring-websocket-4.0.xsd @@ -0,0 +1,217 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java index 72ff870f34a..0ddce92313f 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java @@ -23,6 +23,7 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.web.socket.support.TestWebSocketSession; import static org.mockito.Mockito.*; @@ -45,14 +46,17 @@ public class SubProtocolWebSocketHandlerTests { @Mock SubProtocolHandler defaultHandler; - @Mock MessageChannel channel; + @Mock MessageChannel inClientChannel; + + @Mock + SubscribableChannel outClientChannel; @Before public void setup() { MockitoAnnotations.initMocks(this); - this.webSocketHandler = new SubProtocolWebSocketHandler(this.channel); + this.webSocketHandler = new SubProtocolWebSocketHandler(this.inClientChannel, this.outClientChannel); when(stompHandler.getSupportedProtocols()).thenReturn(Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp")); when(mqttHandler.getSupportedProtocols()).thenReturn(Arrays.asList("MQTT")); @@ -67,8 +71,8 @@ public class SubProtocolWebSocketHandlerTests { this.session.setAcceptedProtocol("v12.sToMp"); this.webSocketHandler.afterConnectionEstablished(session); - verify(this.stompHandler).afterSessionStarted(session, this.channel); - verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.channel); + verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel); + verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel); } @Test @@ -77,7 +81,7 @@ public class SubProtocolWebSocketHandlerTests { this.session.setAcceptedProtocol("v12.sToMp"); this.webSocketHandler.afterConnectionEstablished(session); - verify(this.stompHandler).afterSessionStarted(session, this.channel); + verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel); } @Test(expected=IllegalStateException.class) @@ -94,9 +98,9 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.setDefaultProtocolHandler(defaultHandler); this.webSocketHandler.afterConnectionEstablished(session); - verify(this.defaultHandler).afterSessionStarted(session, this.channel); - verify(this.stompHandler, times(0)).afterSessionStarted(session, this.channel); - verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.channel); + verify(this.defaultHandler).afterSessionStarted(session, this.inClientChannel); + verify(this.stompHandler, times(0)).afterSessionStarted(session, this.inClientChannel); + verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel); } @Test @@ -105,9 +109,9 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.setDefaultProtocolHandler(defaultHandler); this.webSocketHandler.afterConnectionEstablished(session); - verify(this.defaultHandler).afterSessionStarted(session, this.channel); - verify(this.stompHandler, times(0)).afterSessionStarted(session, this.channel); - verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.channel); + verify(this.defaultHandler).afterSessionStarted(session, this.inClientChannel); + verify(this.stompHandler, times(0)).afterSessionStarted(session, this.inClientChannel); + verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel); } @Test @@ -115,7 +119,7 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler)); this.webSocketHandler.afterConnectionEstablished(session); - verify(this.stompHandler).afterSessionStarted(session, this.channel); + verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel); } @Test(expected=IllegalStateException.class) diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistrationTests.java index e34f87923b5..ba180d4b831 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistrationTests.java @@ -51,7 +51,8 @@ public class WebMvcStompEndpointRegistrationTests { @Before public void setup() { - this.wsHandler = new SubProtocolWebSocketHandler(new ExecutorSubscribableChannel()); + this.wsHandler = new SubProtocolWebSocketHandler( + new ExecutorSubscribableChannel(), new ExecutorSubscribableChannel()); this.scheduler = Mockito.mock(TaskScheduler.class); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistryTests.java index cbeff8dd10e..339e06b59ff 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistryTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistryTests.java @@ -22,6 +22,7 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.handler.DefaultUserSessionRegistry; import org.springframework.messaging.simp.handler.UserSessionRegistry; import org.springframework.scheduling.TaskScheduler; @@ -49,8 +50,9 @@ public class WebMvcStompEndpointRegistryTests { @Before public void setup() { - MessageChannel channel = Mockito.mock(MessageChannel.class); - this.webSocketHandler = new SubProtocolWebSocketHandler(channel); + MessageChannel inChannel = Mockito.mock(MessageChannel.class); + SubscribableChannel outChannel = Mockito.mock(SubscribableChannel.class); + this.webSocketHandler = new SubProtocolWebSocketHandler(inChannel, outChannel); this.userSessionRegistry = new DefaultUserSessionRegistry(); TaskScheduler taskScheduler = Mockito.mock(TaskScheduler.class); this.registry = new WebMvcStompEndpointRegistry(webSocketHandler, userSessionRegistry, taskScheduler); @@ -62,7 +64,7 @@ public class WebMvcStompEndpointRegistryTests { this.registry.addEndpoint("/stomp"); - Map protocolHandlers = webSocketHandler.getProtocolHandlers(); + Map protocolHandlers = webSocketHandler.getProtocolHandlerMap(); assertEquals(3, protocolHandlers.size()); assertNotNull(protocolHandlers.get("v10.stomp")); assertNotNull(protocolHandlers.get("v11.stomp")); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/xml/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/xml/MessageBrokerBeanDefinitionParserTests.java new file mode 100644 index 00000000000..2119439f07e --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/xml/MessageBrokerBeanDefinitionParserTests.java @@ -0,0 +1,267 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.messaging.config.xml; + +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Test; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; +import org.springframework.core.io.ClassPathResource; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.simp.handler.*; +import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; +import org.springframework.messaging.support.channel.AbstractSubscribableChannel; +import org.springframework.messaging.support.converter.CompositeMessageConverter; +import org.springframework.messaging.support.converter.MessageConverter; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.web.HttpRequestHandler; +import org.springframework.web.context.support.GenericWebApplicationContext; +import org.springframework.web.servlet.HandlerMapping; +import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; +import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; +import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler; +import org.springframework.web.socket.support.WebSocketHandlerDecorator; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.*; + +/** + * Test fixture for the configuration in websocket-config-broker*.xml test files. + * @author Brian Clozel + */ +public class MessageBrokerBeanDefinitionParserTests { + + private GenericWebApplicationContext appContext; + + @Before + public void setup() { + this.appContext = new GenericWebApplicationContext(); + } + + @Test + public void simpleBroker() { + + loadBeanDefinitions("websocket-config-broker-simple.xml"); + + HandlerMapping hm = this.appContext.getBean(HandlerMapping.class); + assertNotNull(hm); + assertThat(hm, Matchers.instanceOf(SimpleUrlHandlerMapping.class)); + + SimpleUrlHandlerMapping suhm = (SimpleUrlHandlerMapping) hm; + assertThat(suhm.getUrlMap().keySet(), Matchers.hasSize(4)); + assertThat(suhm.getUrlMap().values(), Matchers.hasSize(4)); + + HttpRequestHandler httpRequestHandler = (HttpRequestHandler) suhm.getUrlMap().get("/foo"); + assertNotNull(httpRequestHandler); + assertThat(httpRequestHandler, Matchers.instanceOf(WebSocketHttpRequestHandler.class)); + WebSocketHttpRequestHandler wsHttpRequestHandler = (WebSocketHttpRequestHandler) httpRequestHandler; + WebSocketHandler wsHandler = unwrapWebSocketHandler(wsHttpRequestHandler.getWebSocketHandler()); + assertNotNull(wsHandler); + assertThat(wsHandler, Matchers.instanceOf(SubProtocolWebSocketHandler.class)); + SubProtocolWebSocketHandler subProtocolWsHandler = (SubProtocolWebSocketHandler) wsHandler; + assertEquals(Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"), subProtocolWsHandler.getSubProtocols()); + + httpRequestHandler = (HttpRequestHandler) suhm.getUrlMap().get("/test/**"); + assertNotNull(httpRequestHandler); + assertThat(httpRequestHandler, Matchers.instanceOf(SockJsHttpRequestHandler.class)); + SockJsHttpRequestHandler sockJsHttpRequestHandler = (SockJsHttpRequestHandler) httpRequestHandler; + wsHandler = unwrapWebSocketHandler(sockJsHttpRequestHandler.getWebSocketHandler()); + assertNotNull(wsHandler); + assertThat(wsHandler, Matchers.instanceOf(SubProtocolWebSocketHandler.class)); + assertNotNull(sockJsHttpRequestHandler.getSockJsService()); + + UserDestinationResolver userDestResolver = this.appContext.getBean(UserDestinationResolver.class); + assertNotNull(userDestResolver); + assertThat(userDestResolver, Matchers.instanceOf(DefaultUserDestinationResolver.class)); + DefaultUserDestinationResolver defaultUserDestResolver = (DefaultUserDestinationResolver) userDestResolver; + assertEquals("/personal/", defaultUserDestResolver.getDestinationPrefix()); + + List> subscriberTypes = + Arrays.>asList(SimpAnnotationMethodMessageHandler.class, + UserDestinationMessageHandler.class, SimpleBrokerMessageHandler.class); + testChannel("clientInboundChannel", subscriberTypes, 0); + testExecutor("clientInboundChannel", 1, Integer.MAX_VALUE, 60); + + subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); + testChannel("clientOutboundChannel", subscriberTypes, 0); + testExecutor("clientOutboundChannel", 1, Integer.MAX_VALUE, 60); + + subscriberTypes = Arrays.>asList( + SimpleBrokerMessageHandler.class, UserDestinationMessageHandler.class); + testChannel("brokerChannel", subscriberTypes, 0); + try { + this.appContext.getBean("brokerChannelExecutor", ThreadPoolTaskExecutor.class); + fail("expected exception"); + } + catch (NoSuchBeanDefinitionException ex) { + // expected + } + } + + @Test + public void stompBrokerRelay() { + + loadBeanDefinitions("websocket-config-broker-relay.xml"); + + HandlerMapping hm = this.appContext.getBean(HandlerMapping.class); + assertNotNull(hm); + assertThat(hm, Matchers.instanceOf(SimpleUrlHandlerMapping.class)); + + SimpleUrlHandlerMapping suhm = (SimpleUrlHandlerMapping) hm; + assertThat(suhm.getUrlMap().keySet(), Matchers.hasSize(1)); + assertThat(suhm.getUrlMap().values(), Matchers.hasSize(1)); + assertEquals(2, suhm.getOrder()); + + HttpRequestHandler httpRequestHandler = (HttpRequestHandler) suhm.getUrlMap().get("/foo/**"); + assertNotNull(httpRequestHandler); + assertThat(httpRequestHandler, Matchers.instanceOf(SockJsHttpRequestHandler.class)); + SockJsHttpRequestHandler sockJsHttpRequestHandler = (SockJsHttpRequestHandler) httpRequestHandler; + WebSocketHandler wsHandler = unwrapWebSocketHandler(sockJsHttpRequestHandler.getWebSocketHandler()); + assertNotNull(wsHandler); + assertThat(wsHandler, Matchers.instanceOf(SubProtocolWebSocketHandler.class)); + assertNotNull(sockJsHttpRequestHandler.getSockJsService()); + + UserDestinationResolver userDestResolver = this.appContext.getBean(UserDestinationResolver.class); + assertNotNull(userDestResolver); + assertThat(userDestResolver, Matchers.instanceOf(DefaultUserDestinationResolver.class)); + DefaultUserDestinationResolver defaultUserDestResolver = (DefaultUserDestinationResolver) userDestResolver; + assertEquals("/user/", defaultUserDestResolver.getDestinationPrefix()); + + StompBrokerRelayMessageHandler messageBroker = this.appContext.getBean(StompBrokerRelayMessageHandler.class); + assertNotNull(messageBroker); + assertEquals("login", messageBroker.getSystemLogin()); + assertEquals("pass", messageBroker.getSystemPasscode()); + assertEquals("relayhost", messageBroker.getRelayHost()); + assertEquals(1234, messageBroker.getRelayPort()); + assertEquals("spring.io", messageBroker.getVirtualHost()); + assertEquals(5000, messageBroker.getSystemHeartbeatReceiveInterval()); + assertEquals(5000, messageBroker.getSystemHeartbeatSendInterval()); + assertThat(messageBroker.getDestinationPrefixes(), Matchers.containsInAnyOrder("/topic","/queue")); + + List> subscriberTypes = + Arrays.>asList(SimpAnnotationMethodMessageHandler.class, + UserDestinationMessageHandler.class, StompBrokerRelayMessageHandler.class); + testChannel("clientInboundChannel", subscriberTypes, 0); + testExecutor("clientInboundChannel", 1, Integer.MAX_VALUE, 60); + + subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); + testChannel("clientOutboundChannel", subscriberTypes, 0); + testExecutor("clientOutboundChannel", 1, Integer.MAX_VALUE, 60); + + subscriberTypes = Arrays.>asList( + StompBrokerRelayMessageHandler.class, UserDestinationMessageHandler.class); + testChannel("brokerChannel", subscriberTypes, 0); + try { + this.appContext.getBean("brokerChannelExecutor", ThreadPoolTaskExecutor.class); + fail("expected exception"); + } + catch (NoSuchBeanDefinitionException ex) { + // expected + } + } + + @Test + public void annotationMethodMessageHandler() { + + loadBeanDefinitions("websocket-config-broker-simple.xml"); + + SimpAnnotationMethodMessageHandler annotationMethodMessageHandler = + this.appContext.getBean(SimpAnnotationMethodMessageHandler.class); + + assertNotNull(annotationMethodMessageHandler); + MessageConverter messageConverter = annotationMethodMessageHandler.getMessageConverter(); + assertNotNull(messageConverter); + assertTrue(messageConverter instanceof CompositeMessageConverter); + + + CompositeMessageConverter compositeMessageConverter = this.appContext.getBean(CompositeMessageConverter.class); + assertNotNull(compositeMessageConverter); + + SimpMessagingTemplate simpMessagingTemplate = this.appContext.getBean(SimpMessagingTemplate.class); + assertNotNull(simpMessagingTemplate); + assertEquals("/personal", simpMessagingTemplate.getUserDestinationPrefix()); + + } + + @Test + public void customChannels() { + + loadBeanDefinitions("websocket-config-broker-customchannels.xml"); + + List> subscriberTypes = + Arrays.>asList(SimpAnnotationMethodMessageHandler.class, + UserDestinationMessageHandler.class, SimpleBrokerMessageHandler.class); + + testChannel("clientInboundChannel", subscriberTypes, 1); + testExecutor("clientInboundChannel", 100, 200, 600); + + subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); + + testChannel("clientOutboundChannel", subscriberTypes, 2); + testExecutor("clientOutboundChannel", 101, 201, 601); + + subscriberTypes = Arrays.>asList(SimpleBrokerMessageHandler.class, + UserDestinationMessageHandler.class); + + testChannel("brokerChannel", subscriberTypes, 0); + testExecutor("brokerChannel", 102, 202, 602); + } + + private void testChannel(String channelName, List> subscriberTypes, + int interceptorCount) { + + AbstractSubscribableChannel channel = this.appContext.getBean(channelName, AbstractSubscribableChannel.class); + + for (Class subscriberType : subscriberTypes) { + MessageHandler subscriber = this.appContext.getBean(subscriberType); + assertNotNull("No subsription for " + subscriberType, subscriber); + assertTrue(channel.hasSubscription(subscriber)); + } + + assertEquals(interceptorCount, channel.getInterceptors().size()); + } + + private void testExecutor(String channelName, int corePoolSize, int maxPoolSize, int keepAliveSeconds) { + + ThreadPoolTaskExecutor taskExecutor = + this.appContext.getBean(channelName + "Executor", ThreadPoolTaskExecutor.class); + + assertEquals(corePoolSize, taskExecutor.getCorePoolSize()); + assertEquals(maxPoolSize, taskExecutor.getMaxPoolSize()); + assertEquals(keepAliveSeconds, taskExecutor.getKeepAliveSeconds()); + } + + private void loadBeanDefinitions(String fileName) { + XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.appContext); + ClassPathResource resource = new ClassPathResource(fileName, MessageBrokerBeanDefinitionParserTests.class); + reader.loadBeanDefinitions(resource); + this.appContext.refresh(); + } + + private WebSocketHandler unwrapWebSocketHandler(WebSocketHandler handler) { + return (handler instanceof WebSocketHandlerDecorator) ? + ((WebSocketHandlerDecorator) handler).getLastHandler() : handler; + } + +} \ No newline at end of file diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/config/xml/HandlersBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/config/xml/HandlersBeanDefinitionParserTests.java new file mode 100644 index 00000000000..cbdd88d3fee --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/config/xml/HandlersBeanDefinitionParserTests.java @@ -0,0 +1,291 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.config.xml; + +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Test; +import org.springframework.beans.DirectFieldAccessor; +import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.scheduling.Trigger; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; +import org.springframework.web.context.support.GenericWebApplicationContext; +import org.springframework.web.servlet.HandlerMapping; +import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.server.DefaultHandshakeHandler; +import org.springframework.web.socket.server.HandshakeFailureException; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; +import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler; +import org.springframework.web.socket.sockjs.SockJsService; +import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; +import org.springframework.web.socket.sockjs.transport.handler.EventSourceTransportHandler; +import org.springframework.web.socket.sockjs.transport.handler.HtmlFileTransportHandler; +import org.springframework.web.socket.sockjs.transport.handler.JsonpPollingTransportHandler; +import org.springframework.web.socket.sockjs.transport.handler.JsonpReceivingTransportHandler; +import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; +import org.springframework.web.socket.sockjs.transport.handler.XhrPollingTransportHandler; +import org.springframework.web.socket.sockjs.transport.handler.XhrReceivingTransportHandler; +import org.springframework.web.socket.sockjs.transport.handler.XhrStreamingTransportHandler; + +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ScheduledFuture; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +/** + * Test fixture for HandlersBeanDefinitionParser. + * See test configuration files websocket-config-handlers*.xml. + * + * @author Brian Clozel + */ +public class HandlersBeanDefinitionParserTests { + + private GenericWebApplicationContext appContext; + + + @Before + public void setup() { + appContext = new GenericWebApplicationContext(); + } + + @Test + public void webSocketHandlers() { + loadBeanDefinitions("websocket-config-handlers.xml"); + Map handlersMap = appContext.getBeansOfType(HandlerMapping.class); + assertNotNull(handlersMap); + assertThat(handlersMap.values(), Matchers.hasSize(2)); + + for(HandlerMapping handlerMapping : handlersMap.values()) { + assertTrue(handlerMapping instanceof SimpleUrlHandlerMapping); + SimpleUrlHandlerMapping urlHandlerMapping = (SimpleUrlHandlerMapping) handlerMapping; + + if(urlHandlerMapping.getUrlMap().keySet().contains("/foo")) { + assertThat(urlHandlerMapping.getUrlMap().keySet(),Matchers.contains("/foo","/bar")); + WebSocketHttpRequestHandler handler = (WebSocketHttpRequestHandler) + urlHandlerMapping.getUrlMap().get("/foo"); + assertNotNull(handler); + checkDelegateHandlerType(handler.getWebSocketHandler(), FooWebSocketHandler.class); + HandshakeHandler handshakeHandler = (HandshakeHandler) + new DirectFieldAccessor(handler).getPropertyValue("handshakeHandler"); + assertNotNull(handshakeHandler); + assertTrue(handshakeHandler instanceof DefaultHandshakeHandler); + } + else { + assertThat(urlHandlerMapping.getUrlMap().keySet(),Matchers.contains("/test")); + WebSocketHttpRequestHandler handler = (WebSocketHttpRequestHandler) + urlHandlerMapping.getUrlMap().get("/test"); + assertNotNull(handler); + checkDelegateHandlerType(handler.getWebSocketHandler(), TestWebSocketHandler.class); + HandshakeHandler handshakeHandler = (HandshakeHandler) + new DirectFieldAccessor(handler).getPropertyValue("handshakeHandler"); + assertNotNull(handshakeHandler); + assertTrue(handshakeHandler instanceof DefaultHandshakeHandler); + } + } + } + + @Test + public void websocketHandlersAttributes() { + loadBeanDefinitions("websocket-config-handlers-attributes.xml"); + HandlerMapping handlerMapping = appContext.getBean(HandlerMapping.class); + assertNotNull(handlerMapping); + assertTrue(handlerMapping instanceof SimpleUrlHandlerMapping); + + SimpleUrlHandlerMapping urlHandlerMapping = (SimpleUrlHandlerMapping) handlerMapping; + assertEquals(2, urlHandlerMapping.getOrder()); + + WebSocketHttpRequestHandler handler = (WebSocketHttpRequestHandler) urlHandlerMapping.getUrlMap().get("/foo"); + assertNotNull(handler); + checkDelegateHandlerType(handler.getWebSocketHandler(), FooWebSocketHandler.class); + HandshakeHandler handshakeHandler = (HandshakeHandler) + new DirectFieldAccessor(handler).getPropertyValue("handshakeHandler"); + assertNotNull(handshakeHandler); + assertTrue(handshakeHandler instanceof TestHandshakeHandler); + List handshakeInterceptorList = (List) + new DirectFieldAccessor(handler).getPropertyValue("interceptors"); + assertNotNull(handshakeInterceptorList); + assertThat(handshakeInterceptorList, Matchers.contains( + Matchers.instanceOf(FooTestInterceptor.class), Matchers.instanceOf(BarTestInterceptor.class))); + + handler = (WebSocketHttpRequestHandler) urlHandlerMapping.getUrlMap().get("/test"); + assertNotNull(handler); + checkDelegateHandlerType(handler.getWebSocketHandler(), TestWebSocketHandler.class); + handshakeHandler = (HandshakeHandler) new DirectFieldAccessor(handler).getPropertyValue("handshakeHandler"); + assertNotNull(handshakeHandler); + assertTrue(handshakeHandler instanceof TestHandshakeHandler); + handshakeInterceptorList = (List) + new DirectFieldAccessor(handler).getPropertyValue("interceptors"); + assertNotNull(handshakeInterceptorList); + assertThat(handshakeInterceptorList, Matchers.contains( + Matchers.instanceOf(FooTestInterceptor.class), Matchers.instanceOf(BarTestInterceptor.class))); + + } + + @Test + public void sockJSSupport() { + loadBeanDefinitions("websocket-config-handlers-sockjs.xml"); + SimpleUrlHandlerMapping handlerMapping = appContext.getBean(SimpleUrlHandlerMapping.class); + assertNotNull(handlerMapping); + SockJsHttpRequestHandler testHandler = (SockJsHttpRequestHandler) handlerMapping.getUrlMap().get("/test/**"); + assertNotNull(testHandler); + checkDelegateHandlerType(testHandler.getWebSocketHandler(), TestWebSocketHandler.class); + SockJsService testSockJsService = testHandler.getSockJsService(); + SockJsHttpRequestHandler fooHandler = (SockJsHttpRequestHandler) handlerMapping.getUrlMap().get("/foo/**"); + assertNotNull(fooHandler); + checkDelegateHandlerType(fooHandler.getWebSocketHandler(), FooWebSocketHandler.class); + + SockJsService sockJsService = fooHandler.getSockJsService(); + assertNotNull(sockJsService); + assertEquals(testSockJsService, sockJsService); + + assertThat(sockJsService, Matchers.instanceOf(DefaultSockJsService.class)); + DefaultSockJsService defaultSockJsService = (DefaultSockJsService) sockJsService; + assertThat(defaultSockJsService.getTaskScheduler(), Matchers.instanceOf(ThreadPoolTaskScheduler.class)); + assertThat(defaultSockJsService.getTransportHandlers().values(), Matchers.containsInAnyOrder( + Matchers.instanceOf(XhrPollingTransportHandler.class), + Matchers.instanceOf(XhrReceivingTransportHandler.class), + Matchers.instanceOf(JsonpPollingTransportHandler.class), + Matchers.instanceOf(JsonpReceivingTransportHandler.class), + Matchers.instanceOf(XhrStreamingTransportHandler.class), + Matchers.instanceOf(EventSourceTransportHandler.class), + Matchers.instanceOf(HtmlFileTransportHandler.class), + Matchers.instanceOf(WebSocketTransportHandler.class))); + + } + + @Test + public void sockJSAttributesSupport() { + loadBeanDefinitions("websocket-config-handlers-sockjs-attributes.xml"); + SimpleUrlHandlerMapping handlerMapping = appContext.getBean(SimpleUrlHandlerMapping.class); + assertNotNull(handlerMapping); + SockJsHttpRequestHandler handler = (SockJsHttpRequestHandler) handlerMapping.getUrlMap().get("/test/**"); + assertNotNull(handler); + checkDelegateHandlerType(handler.getWebSocketHandler(), TestWebSocketHandler.class); + SockJsService sockJsService = handler.getSockJsService(); + assertNotNull(sockJsService); + assertThat(sockJsService, Matchers.instanceOf(DefaultSockJsService.class)); + DefaultSockJsService defaultSockJsService = (DefaultSockJsService) sockJsService; + assertThat(defaultSockJsService.getTaskScheduler(), Matchers.instanceOf(TestTaskScheduler.class)); + assertThat(defaultSockJsService.getTransportHandlers().values(), Matchers.containsInAnyOrder( + Matchers.instanceOf(XhrPollingTransportHandler.class), + Matchers.instanceOf(XhrStreamingTransportHandler.class))); + + assertEquals("testSockJsService", defaultSockJsService.getName()); + assertFalse(defaultSockJsService.isWebSocketEnabled()); + assertFalse(defaultSockJsService.isSessionCookieNeeded()); + assertEquals(2048, defaultSockJsService.getStreamBytesLimit()); + assertEquals(256, defaultSockJsService.getDisconnectDelay()); + assertEquals(1024, defaultSockJsService.getHttpMessageCacheSize()); + assertEquals(20, defaultSockJsService.getHeartbeatTime()); + } + + private void loadBeanDefinitions(String fileName) { + XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(appContext); + ClassPathResource resource = new ClassPathResource(fileName, HandlersBeanDefinitionParserTests.class); + reader.loadBeanDefinitions(resource); + appContext.refresh(); + } + + private void checkDelegateHandlerType(WebSocketHandler handler, Class handlerClass) { + do { + handler = (WebSocketHandler) new DirectFieldAccessor(handler).getPropertyValue("delegate"); + } + while (new DirectFieldAccessor(handler).isReadableProperty("delegate")); + assertTrue(handlerClass.isInstance(handler)); + } + +} + +class TestWebSocketHandler implements WebSocketHandler { + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception {} + + @Override + public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception {} + + @Override + public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {} + + @Override + public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {} + + @Override + public boolean supportsPartialMessages() { return false; } +} + +class FooWebSocketHandler extends TestWebSocketHandler { } + +class TestHandshakeHandler implements HandshakeHandler { + @Override + public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) + throws HandshakeFailureException { + return false; + } +} + +class FooTestInterceptor implements HandshakeInterceptor { + @Override + public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) + throws Exception { + return false; + } + + @Override + public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Exception exception) { + } +} + +class BarTestInterceptor extends FooTestInterceptor {} + +class TestTaskScheduler implements TaskScheduler { + @Override + public ScheduledFuture schedule(Runnable task, Trigger trigger) { return null; } + + @Override + public ScheduledFuture schedule(Runnable task, Date startTime) { return null; } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable task, Date startTime, long period) { return null; } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable task, long period) { return null; } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable task, Date startTime, long delay) { return null; } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) { return null; } +} \ No newline at end of file diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/messaging/config/xml/websocket-config-broker-customchannels.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/messaging/config/xml/websocket-config-broker-customchannels.xml new file mode 100644 index 00000000000..5f0a809e6b5 --- /dev/null +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/messaging/config/xml/websocket-config-broker-customchannels.xml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/messaging/config/xml/websocket-config-broker-relay.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/messaging/config/xml/websocket-config-broker-relay.xml new file mode 100644 index 00000000000..a60341a4803 --- /dev/null +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/messaging/config/xml/websocket-config-broker-relay.xml @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/messaging/config/xml/websocket-config-broker-simple.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/messaging/config/xml/websocket-config-broker-simple.xml new file mode 100644 index 00000000000..4447bb341dc --- /dev/null +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/messaging/config/xml/websocket-config-broker-simple.xml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers-attributes.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers-attributes.xml new file mode 100644 index 00000000000..dd2b9a75714 --- /dev/null +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers-attributes.xml @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers-sockjs-attributes.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers-sockjs-attributes.xml new file mode 100644 index 00000000000..38ab610280e --- /dev/null +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers-sockjs-attributes.xml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers-sockjs.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers-sockjs.xml new file mode 100644 index 00000000000..f37802b1d1b --- /dev/null +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers-sockjs.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers.xml new file mode 100644 index 00000000000..8a9cb3794b7 --- /dev/null +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/server/config/xml/websocket-config-handlers.xml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + +