diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractStompEndpointRegistration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractStompEndpointRegistration.java index 4bd710bb867..049b1f3cfdc 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractStompEndpointRegistration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractStompEndpointRegistration.java @@ -22,11 +22,13 @@ import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandl import org.springframework.scheduling.TaskScheduler; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; +import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.DefaultHandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.config.SockJsServiceRegistration; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; +import org.springframework.web.socket.support.WebSocketHandlerDecorator; /** @@ -39,7 +41,7 @@ public abstract class AbstractStompEndpointRegistration implements StompEndpo private final String[] paths; - private final SubProtocolWebSocketHandler wsHandler; + private final WebSocketHandler wsHandler; private HandshakeHandler handshakeHandler; @@ -48,7 +50,7 @@ public abstract class AbstractStompEndpointRegistration implements StompEndpo private final TaskScheduler sockJsTaskScheduler; - public AbstractStompEndpointRegistration(String[] paths, SubProtocolWebSocketHandler webSocketHandler, + public AbstractStompEndpointRegistration(String[] paths, WebSocketHandler webSocketHandler, TaskScheduler sockJsTaskScheduler) { Assert.notEmpty(paths, "No paths specified"); @@ -115,7 +117,7 @@ public abstract class AbstractStompEndpointRegistration implements StompEndpo if (handler instanceof DefaultHandshakeHandler) { DefaultHandshakeHandler defaultHandshakeHandler = (DefaultHandshakeHandler) handler; if (ObjectUtils.isEmpty(defaultHandshakeHandler.getSupportedProtocols())) { - Set protocols = this.wsHandler.getSupportedProtocols(); + Set protocols = findSubProtocolWebSocketHandler(this.wsHandler).getSupportedProtocols(); defaultHandshakeHandler.setSupportedProtocols(protocols.toArray(new String[protocols.size()])); } } @@ -123,12 +125,19 @@ public abstract class AbstractStompEndpointRegistration implements StompEndpo return handler; } + private static SubProtocolWebSocketHandler findSubProtocolWebSocketHandler(WebSocketHandler webSocketHandler) { + WebSocketHandler actual = (webSocketHandler instanceof WebSocketHandlerDecorator) ? + ((WebSocketHandlerDecorator) webSocketHandler).getLastHandler() : webSocketHandler; + Assert.isInstanceOf(SubProtocolWebSocketHandler.class, actual, + "No SubProtocolWebSocketHandler found: " + webSocketHandler); + return (SubProtocolWebSocketHandler) actual; + } + protected abstract void addSockJsServiceMapping(M mappings, SockJsService sockJsService, - SubProtocolWebSocketHandler wsHandler, String pathPattern); + WebSocketHandler wsHandler, String pathPattern); protected abstract void addWebSocketHandlerMapping(M mappings, - SubProtocolWebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path); - + WebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path); private class StompSockJsServiceRegistration extends SockJsServiceRegistration { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistration.java index 3f265aab3ac..52f2423241f 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistration.java @@ -16,11 +16,11 @@ package org.springframework.messaging.simp.config; -import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler; import org.springframework.scheduling.TaskScheduler; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.HttpRequestHandler; +import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler; @@ -38,8 +38,8 @@ public class ServletStompEndpointRegistration extends AbstractStompEndpointRegistration> { - public ServletStompEndpointRegistration(String[] paths, SubProtocolWebSocketHandler wsHandler, - TaskScheduler sockJsTaskScheduler) { + public ServletStompEndpointRegistration(String[] paths, + WebSocketHandler wsHandler, TaskScheduler sockJsTaskScheduler) { super(paths, wsHandler, sockJsTaskScheduler); } @@ -51,7 +51,7 @@ public class ServletStompEndpointRegistration @Override protected void addSockJsServiceMapping(MultiValueMap mappings, - SockJsService sockJsService, SubProtocolWebSocketHandler wsHandler, String pathPattern) { + SockJsService sockJsService, WebSocketHandler wsHandler, String pathPattern) { SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, wsHandler); mappings.add(httpHandler, pathPattern); @@ -59,7 +59,7 @@ public class ServletStompEndpointRegistration @Override protected void addWebSocketHandlerMapping(MultiValueMap mappings, - SubProtocolWebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path) { + WebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path) { WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(wsHandler, handshakeHandler); mappings.add(handler, path); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java index a270e8dc50e..3e1050c9205 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java @@ -30,6 +30,8 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.HttpRequestHandler; import org.springframework.web.servlet.handler.AbstractHandlerMapping; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.support.WebSocketHandlerDecorator; /** @@ -40,7 +42,9 @@ import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; */ public class ServletStompEndpointRegistry implements StompEndpointRegistry { - private final SubProtocolWebSocketHandler wsHandler; + private final WebSocketHandler webSocketHandler; + + private final SubProtocolWebSocketHandler subProtocolWebSocketHandler; private final StompProtocolHandler stompHandler; @@ -49,23 +53,36 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry { private final TaskScheduler sockJsScheduler; - public ServletStompEndpointRegistry(SubProtocolWebSocketHandler webSocketHandler, + public ServletStompEndpointRegistry(WebSocketHandler webSocketHandler, MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) { Assert.notNull(webSocketHandler); Assert.notNull(userQueueSuffixResolver); - this.wsHandler = webSocketHandler; + this.webSocketHandler = webSocketHandler; + this.subProtocolWebSocketHandler = findSubProtocolWebSocketHandler(webSocketHandler); this.stompHandler = new StompProtocolHandler(); this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver); this.sockJsScheduler = defaultSockJsTaskScheduler; } + private static SubProtocolWebSocketHandler findSubProtocolWebSocketHandler(WebSocketHandler webSocketHandler) { + + WebSocketHandler actual = (webSocketHandler instanceof WebSocketHandlerDecorator) ? + ((WebSocketHandlerDecorator) webSocketHandler).getLastHandler() : webSocketHandler; + + Assert.isInstanceOf(SubProtocolWebSocketHandler.class, actual, + "No SubProtocolWebSocketHandler found: " + webSocketHandler); + + return (SubProtocolWebSocketHandler) actual; + } + @Override public StompEndpointRegistration addEndpoint(String... paths) { - this.wsHandler.addProtocolHandler(this.stompHandler); - ServletStompEndpointRegistration r = new ServletStompEndpointRegistration(paths, this.wsHandler, this.sockJsScheduler); + this.subProtocolWebSocketHandler.addProtocolHandler(this.stompHandler); + ServletStompEndpointRegistration r = new ServletStompEndpointRegistration( + paths, this.webSocketHandler, this.sockJsScheduler); this.registrations.add(r); return r; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java index 9c0d8d5033c..ef077487fab 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java @@ -34,6 +34,7 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.handler.AbstractHandlerMapping; +import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.config.SockJsServiceRegistration; @@ -65,7 +66,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport { } @Bean - public SubProtocolWebSocketHandler subProtocolWebSocketHandler() { + public WebSocketHandler subProtocolWebSocketHandler() { SubProtocolWebSocketHandler wsHandler = new SubProtocolWebSocketHandler(webSocketRequestChannel()); webSocketResponseChannel().subscribe(wsHandler); return wsHandler; diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/AbstractStompEndpointRegistrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/AbstractStompEndpointRegistrationTests.java index 0a81564b7d9..e5ab9b4f6b2 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/AbstractStompEndpointRegistrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/AbstractStompEndpointRegistrationTests.java @@ -25,6 +25,7 @@ import org.mockito.Mockito; import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler; import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; import org.springframework.scheduling.TaskScheduler; +import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.DefaultHandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.sockjs.SockJsService; @@ -122,13 +123,13 @@ public class AbstractStompEndpointRegistrationTests { @Override protected void addSockJsServiceMapping(List mappings, SockJsService sockJsService, - SubProtocolWebSocketHandler wsHandler, String pathPattern) { + WebSocketHandler wsHandler, String pathPattern) { mappings.add(new Mapping(wsHandler, pathPattern, sockJsService)); } @Override - protected void addWebSocketHandlerMapping(List mappings, SubProtocolWebSocketHandler wsHandler, + protected void addWebSocketHandlerMapping(List mappings, WebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path) { mappings.add(new Mapping(wsHandler, path, handshakeHandler)); @@ -137,7 +138,7 @@ public class AbstractStompEndpointRegistrationTests { private static class Mapping { - private final SubProtocolWebSocketHandler webSocketHandler; + private final WebSocketHandler webSocketHandler; private final String path; @@ -145,14 +146,14 @@ public class AbstractStompEndpointRegistrationTests { private final DefaultSockJsService sockJsService; - public Mapping(SubProtocolWebSocketHandler handler, String path, SockJsService sockJsService) { + public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) { this.webSocketHandler = handler; this.path = path; this.handshakeHandler = null; this.sockJsService = (DefaultSockJsService) sockJsService; } - public Mapping(SubProtocolWebSocketHandler h, String path, HandshakeHandler hh) { + public Mapping(WebSocketHandler h, String path, HandshakeHandler hh) { this.webSocketHandler = h; this.path = path; this.handshakeHandler = hh;