diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java index cbb9353d989..aae5f56f0d1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java @@ -260,7 +260,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { if (factoriesElement != null) { ManagedList factories = extractBeanSubElements(factoriesElement, context); RootBeanDefinition factoryBean = new RootBeanDefinition(DecoratingFactoryBean.class); - factoryBean.getConstructorArgumentValues().addIndexedArgumentValue(0, handlerDef); + factoryBean.getConstructorArgumentValues().addIndexedArgumentValue(0, result); factoryBean.getConstructorArgumentValues().addIndexedArgumentValue(1, factories); result = new RuntimeBeanReference(registerBeanDef(factoryBean, context, source)); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 192aaa53a56..4132339a3e9 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -61,6 +61,8 @@ import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator; +import org.springframework.web.socket.handler.LoggingWebSocketHandlerDecorator; import org.springframework.web.socket.handler.TestWebSocketSession; import org.springframework.web.socket.handler.WebSocketHandlerDecorator; import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory; @@ -75,8 +77,15 @@ import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; -import static org.hamcrest.Matchers.*; -import static org.junit.Assert.*; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * Test fixture for MessageBrokerBeanDefinitionParser. @@ -123,9 +132,15 @@ public class MessageBrokerBeanDefinitionParserTests { wsHttpRequestHandler.getWebSocketHandler().afterConnectionEstablished(session); assertEquals(true, session.getAttributes().get("decorated")); - WebSocketHandler wsHandler = unwrapWebSocketHandler(wsHttpRequestHandler.getWebSocketHandler()); - assertNotNull(wsHandler); + WebSocketHandler wsHandler = wsHttpRequestHandler.getWebSocketHandler(); + assertThat(wsHandler, Matchers.instanceOf(ExceptionWebSocketHandlerDecorator.class)); + wsHandler = ((ExceptionWebSocketHandlerDecorator) wsHandler).getDelegate(); + assertThat(wsHandler, Matchers.instanceOf(LoggingWebSocketHandlerDecorator.class)); + wsHandler = ((LoggingWebSocketHandlerDecorator) wsHandler).getDelegate(); + assertThat(wsHandler, Matchers.instanceOf(TestWebSocketHandlerDecorator.class)); + wsHandler = ((TestWebSocketHandlerDecorator) wsHandler).getDelegate(); assertThat(wsHandler, Matchers.instanceOf(SubProtocolWebSocketHandler.class)); + assertSame(wsHandler, this.appContext.getBean(MessageBrokerBeanDefinitionParser.WEB_SOCKET_HANDLER_BEAN_NAME)); SubProtocolWebSocketHandler subProtocolWsHandler = (SubProtocolWebSocketHandler) wsHandler; assertEquals(Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"), subProtocolWsHandler.getSubProtocols()); @@ -463,12 +478,19 @@ class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorF @Override public WebSocketHandler decorate(WebSocketHandler handler) { - return new WebSocketHandlerDecorator(handler) { - @Override - public void afterConnectionEstablished(WebSocketSession session) throws Exception { - session.getAttributes().put("decorated", true); - super.afterConnectionEstablished(session); - } - }; + return new TestWebSocketHandlerDecorator(handler); + } +} + +class TestWebSocketHandlerDecorator extends WebSocketHandlerDecorator { + + public TestWebSocketHandlerDecorator(WebSocketHandler delegate) { + super(delegate); + } + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + session.getAttributes().put("decorated", true); + super.afterConnectionEstablished(session); } } \ No newline at end of file