Browse Source

Allow plugging in a WebSocketHandlerDecorator

The WebSocketMessageBroker config now allows wrapping the
SubProtocolWebSocketHandler to enable advanced use cases that may
require access to the underlying WebSocketSession.

Issue: SPR-12314
pull/661/head
Rossen Stoyanchev 12 years ago
parent
commit
97596fb9f6
  1. 73
      spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java
  2. 14
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java
  3. 40
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketTransportRegistration.java
  4. 42
      spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecoratorFactory.java
  5. 32
      spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd
  6. 24
      spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java
  7. 87
      spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java
  8. 6
      spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml

73
spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java

@ -21,6 +21,9 @@ import java.util.Collections; @@ -21,6 +21,9 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import org.w3c.dom.Element;
import org.springframework.beans.MutablePropertyValues;
@ -89,7 +92,9 @@ import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; @@ -89,7 +92,9 @@ import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
*/
class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
private static final String SOCKJS_SCHEDULER_BEAN_NAME = "messageBrokerSockJsScheduler";
public static final String WEB_SOCKET_HANDLER_BEAN_NAME = "subProtocolWebSocketHandler";
public static final String SOCKJS_SCHEDULER_BEAN_NAME = "messageBrokerSockJsScheduler";
private static final int DEFAULT_MAPPING_ORDER = 1;
@ -156,7 +161,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { @@ -156,7 +161,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
scopeConfigurer.getPropertyValues().add("scopes", scopeMap);
registerBeanDefByName("webSocketScopeConfigurer", scopeConfigurer, context, source);
registerWebSocketMessageBrokerStats(subProtoHandler, broker, inChannel, outChannel, context, source);
registerWebSocketMessageBrokerStats(broker, inChannel, outChannel, context, source);
context.popAndRegisterContainingComponent();
return null;
@ -228,8 +233,10 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { @@ -228,8 +233,10 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
cavs.addIndexedArgumentValue(0, inChannel);
cavs.addIndexedArgumentValue(1, outChannel);
RootBeanDefinition beanDef = new RootBeanDefinition(SubProtocolWebSocketHandler.class, cavs, null);
beanDef.getPropertyValues().addPropertyValue("protocolHandlers", stompHandlerDef);
RootBeanDefinition handlerDef = new RootBeanDefinition(SubProtocolWebSocketHandler.class, cavs, null);
handlerDef.getPropertyValues().addPropertyValue("protocolHandlers", stompHandlerDef);
registerBeanDefByName(WEB_SOCKET_HANDLER_BEAN_NAME, handlerDef, context, source);
RuntimeBeanReference result = new RuntimeBeanReference(WEB_SOCKET_HANDLER_BEAN_NAME);
Element transportElem = DomUtils.getChildElementByTagName(element, "transport");
if (transportElem != null) {
@ -237,13 +244,21 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { @@ -237,13 +244,21 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
stompHandlerDef.getPropertyValues().add("messageSizeLimit", transportElem.getAttribute("message-size"));
}
if (transportElem.hasAttribute("send-timeout")) {
beanDef.getPropertyValues().add("sendTimeLimit", transportElem.getAttribute("send-timeout"));
handlerDef.getPropertyValues().add("sendTimeLimit", transportElem.getAttribute("send-timeout"));
}
if (transportElem.hasAttribute("send-buffer-size")) {
beanDef.getPropertyValues().add("sendBufferSizeLimit", transportElem.getAttribute("send-buffer-size"));
handlerDef.getPropertyValues().add("sendBufferSizeLimit", transportElem.getAttribute("send-buffer-size"));
}
Element factoriesElement = DomUtils.getChildElementByTagName(transportElem, "decorator-factories");
if (factoriesElement != null) {
ManagedList<Object> factories = extractBeanSubElements(factoriesElement, context);
RootBeanDefinition factoryBean = new RootBeanDefinition(DecoratingFactoryBean.class);
factoryBean.getConstructorArgumentValues().addIndexedArgumentValue(0, handlerDef);
factoryBean.getConstructorArgumentValues().addIndexedArgumentValue(1, factories);
result = new RuntimeBeanReference(registerBeanDef(factoryBean, context, source));
}
}
return new RuntimeBeanReference(registerBeanDef(beanDef, context, source));
return result;
}
private RuntimeBeanReference registerRequestHandler(Element element, RuntimeBeanReference subProtoHandler,
@ -448,14 +463,15 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { @@ -448,14 +463,15 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
return new RuntimeBeanReference(registerBeanDef(beanDef, context, source));
}
private void registerWebSocketMessageBrokerStats(RuntimeBeanReference subProtoHandler,
RootBeanDefinition broker, RuntimeBeanReference inChannel, RuntimeBeanReference outChannel,
ParserContext context, Object source) {
private void registerWebSocketMessageBrokerStats(RootBeanDefinition broker, RuntimeBeanReference inChannel,
RuntimeBeanReference outChannel, ParserContext context, Object source) {
RootBeanDefinition beanDef = new RootBeanDefinition(WebSocketMessageBrokerStats.class);
beanDef.getPropertyValues().add("subProtocolWebSocketHandler", subProtoHandler);
if (StompBrokerRelayMessageHandler.class.equals(broker.getBeanClass())) {
RuntimeBeanReference webSocketHandler = new RuntimeBeanReference(WEB_SOCKET_HANDLER_BEAN_NAME);
beanDef.getPropertyValues().add("subProtocolWebSocketHandler", webSocketHandler);
if (StompBrokerRelayMessageHandler.class.equals(broker.getBeanClass())) {
beanDef.getPropertyValues().add("stompBrokerRelay", broker);
}
String name = inChannel.getBeanName() + "Executor";
@ -486,4 +502,37 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { @@ -486,4 +502,37 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
context.registerComponent(new BeanComponentDefinition(beanDef, name));
}
private static class DecoratingFactoryBean implements FactoryBean<WebSocketHandler> {
private final WebSocketHandler handler;
private final List<WebSocketHandlerDecoratorFactory> factories;
private DecoratingFactoryBean(WebSocketHandler handler, List<WebSocketHandlerDecoratorFactory> factories) {
this.handler = handler;
this.factories = factories;
}
@Override
public WebSocketHandler getObject() throws Exception {
WebSocketHandler result = this.handler;
for (WebSocketHandlerDecoratorFactory factory : this.factories) {
result = factory.decorate(result);
}
return result;
}
@Override
public Class<?> getObjectType() {
return WebSocketHandler.class;
}
@Override
public boolean isSingleton() {
return true;
}
}
}

14
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java

@ -23,9 +23,12 @@ import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler; @@ -23,9 +23,12 @@ import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler;
import org.springframework.messaging.simp.config.AbstractMessageBrokerConfiguration;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.Assert;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.WebSocketMessageBrokerStats;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
/**
@ -47,7 +50,9 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac @@ -47,7 +50,9 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
@Bean
public HandlerMapping stompWebSocketHandlerMapping() {
WebMvcStompEndpointRegistry registry = new WebMvcStompEndpointRegistry(subProtocolWebSocketHandler(),
WebSocketHandler handler = subProtocolWebSocketHandler();
handler = decorateWebSocketHandler(handler);
WebMvcStompEndpointRegistry registry = new WebMvcStompEndpointRegistry(handler,
getTransportRegistration(), userSessionRegistry(), messageBrokerSockJsTaskScheduler());
registry.setApplicationContext(getApplicationContext());
registerStompEndpoints(registry);
@ -59,6 +64,13 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac @@ -59,6 +64,13 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
return new SubProtocolWebSocketHandler(clientInboundChannel(), clientOutboundChannel());
}
protected WebSocketHandler decorateWebSocketHandler(WebSocketHandler handler) {
for (WebSocketHandlerDecoratorFactory factory : getTransportRegistration().getDecoratorFactories()) {
handler = factory.decorate(handler);
}
return handler;
}
protected final WebSocketTransportRegistration getTransportRegistration() {
if (this.transportRegistration == null) {
this.transportRegistration = new WebSocketTransportRegistration();

40
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketTransportRegistration.java

@ -16,6 +16,12 @@ @@ -16,6 +16,12 @@
package org.springframework.web.socket.config.annotation;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Configure the processing of messages received from and sent to WebSocket clients.
*
@ -30,6 +36,9 @@ public class WebSocketTransportRegistration { @@ -30,6 +36,9 @@ public class WebSocketTransportRegistration {
private Integer sendBufferSizeLimit;
private final List<WebSocketHandlerDecoratorFactory> decoratorFactories =
new ArrayList<WebSocketHandlerDecoratorFactory>(2);
/**
* Configure the maximum size for an incoming sub-protocol message.
@ -147,4 +156,35 @@ public class WebSocketTransportRegistration { @@ -147,4 +156,35 @@ public class WebSocketTransportRegistration {
protected Integer getSendBufferSizeLimit() {
return this.sendBufferSizeLimit;
}
/**
* Configure one or more factories to decorate the handler used to process
* WebSocket messages. This may be useful in some advanced use cases, for
* example to allow Spring Security to forcibly close the WebSocket session
* when the corresponding HTTP session expires.
* @since 4.1.2
*/
public WebSocketTransportRegistration setDecoratorFactories(WebSocketHandlerDecoratorFactory... factories) {
if (factories != null) {
this.decoratorFactories.addAll(Arrays.asList(factories));
}
return this;
}
/**
* Add a factory that to decorate the handler used to process WebSocket
* messages. This may be useful for some advanced use cases, for example
* to allow Spring Security to forcibly close the WebSocket session when
* the corresponding HTTP session expires.
* @since 4.1.2
*/
public WebSocketTransportRegistration addDecoratorFactory(WebSocketHandlerDecoratorFactory factory) {
this.decoratorFactories.add(factory);
return this;
}
protected List<WebSocketHandlerDecoratorFactory> getDecoratorFactories() {
return this.decoratorFactories;
}
}

42
spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecoratorFactory.java

@ -0,0 +1,42 @@ @@ -0,0 +1,42 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.handler;
import org.springframework.web.socket.WebSocketHandler;
/**
* A factory for applying decorators to a WebSocketHandler.
*
* <p>Decoration should be done through sub-classing
* {@link org.springframework.web.socket.handler.WebSocketHandlerDecorator
* WebSocketHandlerDecorator} to allow any code to traverse decorators and/or
* unwrap the original handler when necessary .
*
* @author Rossen Stoyanchev
* @since 4.1.2
*/
public interface WebSocketHandlerDecoratorFactory {
/**
* Decorate the given WebSocketHandler.
* @param handler the handler to be decorated.
* @return the same handler or the handler wrapped with a sub-class of
* {@code WebSocketHandlerDecorator}.
*/
WebSocketHandler decorate(WebSocketHandler handler);
}

32
spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd

@ -497,6 +497,38 @@ @@ -497,6 +497,38 @@
]]></xsd:documentation>
</xsd:annotation>
<xsd:complexType>
<xsd:sequence>
<xsd:element name="decorator-factories" maxOccurs="1" minOccurs="0">
<xsd:complexType>
<xsd:annotation>
<xsd:documentation source="org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory"><![CDATA[
Configure one or more factories to decorate the handler used to process WebSocket
messages. This may be useful for some advanced use cases, for example to allow
Spring Security to forcibly close the WebSocket session when the corresponding
HTTP session expires.
]]></xsd:documentation>
</xsd:annotation>
<xsd:sequence>
<xsd:choice minOccurs="1" maxOccurs="unbounded">
<xsd:element ref="beans:bean">
<xsd:annotation>
<xsd:documentation source="org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory"><![CDATA[
A WebSocketHandlerDecoratorFactory bean definition.
]]></xsd:documentation>
</xsd:annotation>
</xsd:element>
<xsd:element ref="beans:ref">
<xsd:annotation>
<xsd:documentation source="org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory"><![CDATA[
A reference to a WebSocketHandlerDecoratorFactory bean.
]]></xsd:documentation>
</xsd:annotation>
</xsd:element>
</xsd:choice>
</xsd:sequence>
</xsd:complexType>
</xsd:element>
</xsd:sequence>
<xsd:attribute name="message-size" type="xsd:string">
<xsd:annotation>
<xsd:documentation><![CDATA[

24
spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java

@ -62,7 +62,10 @@ import org.springframework.web.context.support.GenericWebApplicationContext; @@ -62,7 +62,10 @@ import org.springframework.web.context.support.GenericWebApplicationContext;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TestWebSocketSession;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
@ -93,7 +96,7 @@ public class MessageBrokerBeanDefinitionParserTests { @@ -93,7 +96,7 @@ public class MessageBrokerBeanDefinitionParserTests {
@Test
public void simpleBroker() {
public void simpleBroker() throws Exception {
loadBeanDefinitions("websocket-config-broker-simple.xml");
HandlerMapping hm = this.appContext.getBean(HandlerMapping.class);
@ -113,6 +116,10 @@ public class MessageBrokerBeanDefinitionParserTests { @@ -113,6 +116,10 @@ public class MessageBrokerBeanDefinitionParserTests {
List<HandshakeInterceptor> interceptors = wsHttpRequestHandler.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
WebSocketSession session = new TestWebSocketSession("id");
wsHttpRequestHandler.getWebSocketHandler().afterConnectionEstablished(session);
assertEquals(true, session.getAttributes().get("decorated"));
WebSocketHandler wsHandler = unwrapWebSocketHandler(wsHttpRequestHandler.getWebSocketHandler());
assertNotNull(wsHandler);
assertThat(wsHandler, Matchers.instanceOf(SubProtocolWebSocketHandler.class));
@ -429,7 +436,6 @@ class CustomArgumentResolver implements HandlerMethodArgumentResolver { @@ -429,7 +436,6 @@ class CustomArgumentResolver implements HandlerMethodArgumentResolver {
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
return null;
}
}
class CustomReturnValueHandler implements HandlerMethodReturnValueHandler {
@ -443,4 +449,18 @@ class CustomReturnValueHandler implements HandlerMethodReturnValueHandler { @@ -443,4 +449,18 @@ class CustomReturnValueHandler implements HandlerMethodReturnValueHandler {
public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> message) throws Exception {
}
}
class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorFactory {
@Override
public WebSocketHandler decorate(WebSocketHandler handler) {
return new WebSocketHandlerDecorator(handler) {
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
session.getAttributes().put("decorated", true);
super.afterConnectionEstablished(session);
}
};
}
}

87
spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java

@ -22,9 +22,9 @@ import java.util.Map; @@ -22,9 +22,9 @@ import java.util.Map;
import java.util.Set;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import org.junit.Before;
import org.junit.Test;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@ -43,13 +43,17 @@ import org.springframework.stereotype.Controller; @@ -43,13 +43,17 @@ import org.springframework.stereotype.Controller;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.config.WebSocketMessageBrokerStats;
import org.springframework.web.socket.handler.TestWebSocketSession;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
import org.springframework.web.socket.messaging.StompTextMessageBuilder;
import org.springframework.web.socket.messaging.SubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
@ -62,19 +66,11 @@ import static org.junit.Assert.assertEquals; @@ -62,19 +66,11 @@ import static org.junit.Assert.assertEquals;
*/
public class WebSocketMessageBrokerConfigurationSupportTests {
private AnnotationConfigApplicationContext config;
@Before
public void setupOnce() {
this.config = new AnnotationConfigApplicationContext();
this.config.register(TestWebSocketMessageBrokerConfiguration.class, TestSimpleMessageBrokerConfig.class);
this.config.refresh();
}
@Test
public void handlerMapping() {
SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) this.config.getBean(HandlerMapping.class);
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) config.getBean(HandlerMapping.class);
assertEquals(1, hm.getOrder());
Map<String, Object> handlerMap = hm.getHandlerMap();
@ -84,8 +80,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @@ -84,8 +80,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void clientInboundChannelSendMessage() throws Exception {
TestChannel channel = this.config.getBean("clientInboundChannel", TestChannel.class);
SubProtocolWebSocketHandler webSocketHandler = this.config.getBean(SubProtocolWebSocketHandler.class);
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
TestChannel channel = config.getBean("clientInboundChannel", TestChannel.class);
SubProtocolWebSocketHandler webSocketHandler = config.getBean(SubProtocolWebSocketHandler.class);
WebSocketSession session = new TestWebSocketSession("s1");
webSocketHandler.afterConnectionEstablished(session);
@ -102,7 +99,8 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @@ -102,7 +99,8 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void clientOutboundChannelChannel() {
TestChannel channel = this.config.getBean("clientOutboundChannel", TestChannel.class);
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
TestChannel channel = config.getBean("clientOutboundChannel", TestChannel.class);
Set<MessageHandler> handlers = channel.getSubscribers();
assertEquals(1, handlers.size());
@ -111,8 +109,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @@ -111,8 +109,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void webSocketTransportOptions() {
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
SubProtocolWebSocketHandler subProtocolWebSocketHandler =
this.config.getBean("subProtocolWebSocketHandler", SubProtocolWebSocketHandler.class);
config.getBean("subProtocolWebSocketHandler", SubProtocolWebSocketHandler.class);
assertEquals(1024 * 1024, subProtocolWebSocketHandler.getSendBufferSizeLimit());
assertEquals(25 * 1000, subProtocolWebSocketHandler.getSendTimeLimit());
@ -126,8 +125,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @@ -126,8 +125,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void messageBrokerSockJsTaskScheduler() {
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
ThreadPoolTaskScheduler taskScheduler =
this.config.getBean("messageBrokerSockJsTaskScheduler", ThreadPoolTaskScheduler.class);
config.getBean("messageBrokerSockJsTaskScheduler", ThreadPoolTaskScheduler.class);
ScheduledThreadPoolExecutor executor = taskScheduler.getScheduledThreadPoolExecutor();
assertEquals(Runtime.getRuntime().availableProcessors(), executor.getCorePoolSize());
@ -136,8 +136,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @@ -136,8 +136,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void webSocketMessageBrokerStats() {
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
String name = "webSocketMessageBrokerStats";
WebSocketMessageBrokerStats stats = this.config.getBean(name, WebSocketMessageBrokerStats.class);
WebSocketMessageBrokerStats stats = config.getBean(name, WebSocketMessageBrokerStats.class);
String actual = stats.toString();
String expected = "WebSocketSession\\[0 current WS\\(0\\)-HttpStream\\(0\\)-HttpPoll\\(0\\), " +
"0 total, 0 closed abnormally \\(0 connect failure, 0 send limit, 0 transport error\\)\\], " +
@ -150,6 +151,29 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @@ -150,6 +151,29 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
assertTrue("\nExpected: " + expected.replace("\\", "") + "\n Actual: " + actual, actual.matches(expected));
}
@Test
public void webSocketHandlerDecorator() throws Exception {
ApplicationContext config = createConfig(WebSocketHandlerDecoratorConfig.class);
WebSocketHandler handler = config.getBean(SubProtocolWebSocketHandler.class);
assertNotNull(handler);
SimpleUrlHandlerMapping mapping = (SimpleUrlHandlerMapping) config.getBean("stompWebSocketHandlerMapping");
WebSocketHttpRequestHandler httpHandler = (WebSocketHttpRequestHandler) mapping.getHandlerMap().get("/test");
handler = httpHandler.getWebSocketHandler();
WebSocketSession session = new TestWebSocketSession("id");
handler.afterConnectionEstablished(session);
assertEquals(true, session.getAttributes().get("decorated"));
}
private ApplicationContext createConfig(Class<?>... configClasses) {
AnnotationConfigApplicationContext config = new AnnotationConfigApplicationContext();
config.register(configClasses);
config.refresh();
return config;
}
@Controller
static class TestController {
@ -167,7 +191,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @@ -167,7 +191,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
}
@Configuration
static class TestSimpleMessageBrokerConfig extends AbstractWebSocketMessageBrokerConfigurer {
static class TestConfigurer extends AbstractWebSocketMessageBrokerConfigurer {
@Bean
public TestController subscriptionController() {
@ -188,7 +212,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @@ -188,7 +212,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
}
@Configuration
static class TestWebSocketMessageBrokerConfiguration extends DelegatingWebSocketMessageBrokerConfiguration {
static class TestChannelConfig extends DelegatingWebSocketMessageBrokerConfiguration {
@Override
@Bean
@ -208,6 +232,31 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @@ -208,6 +232,31 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
}
}
@Configuration
static class WebSocketHandlerDecoratorConfig extends WebSocketMessageBrokerConfigurationSupport {
@Override
protected void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/test");
}
@Override
protected void configureWebSocketTransport(WebSocketTransportRegistration registry) {
registry.addDecoratorFactory(new WebSocketHandlerDecoratorFactory() {
@Override
public WebSocketHandlerDecorator decorate(WebSocketHandler handler) {
return new WebSocketHandlerDecorator(handler) {
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
session.getAttributes().put("decorated", true);
super.afterConnectionEstablished(session);
}
};
}
});
}
}
private static class TestChannel extends ExecutorSubscribableChannel {
private final List<Message<?>> messages = new ArrayList<>();

6
spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml

@ -9,7 +9,11 @@ @@ -9,7 +9,11 @@
path-matcher="pathMatcher">
<!-- message-size=128*1024, send-buffer-size=1024*1024 -->
<websocket:transport message-size="131072" send-timeout="25000" send-buffer-size="1048576" />
<websocket:transport message-size="131072" send-timeout="25000" send-buffer-size="1048576">
<websocket:decorator-factories>
<bean class="org.springframework.web.socket.config.TestWebSocketHandlerDecoratorFactory" />
</websocket:decorator-factories>
</websocket:transport>
<websocket:stomp-endpoint path=" /foo,/bar">
<websocket:handshake-handler ref="myHandler"/>

Loading…
Cancel
Save