diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandler.java index c30c0b4139f..f2ede8cba2c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,14 +16,19 @@ package org.springframework.web.socket.sockjs.transport.handler; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.SubProtocolCapable; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.TextWebSocketHandler; +import org.springframework.web.socket.handler.WebSocketHandlerDecorator; import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession; @@ -42,17 +47,19 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo * @author Rossen Stoyanchev * @since 4.0 */ -public class SockJsWebSocketHandler extends TextWebSocketHandler { +public class SockJsWebSocketHandler extends TextWebSocketHandler implements SubProtocolCapable { private final SockJsServiceConfig sockJsServiceConfig; private final WebSocketServerSockJsSession sockJsSession; + private final List subProtocols; + private final AtomicInteger sessionCount = new AtomicInteger(0); - public SockJsWebSocketHandler(SockJsServiceConfig serviceConfig, - WebSocketHandler webSocketHandler, WebSocketServerSockJsSession sockJsSession) { + public SockJsWebSocketHandler(SockJsServiceConfig serviceConfig, WebSocketHandler webSocketHandler, + WebSocketServerSockJsSession sockJsSession) { Assert.notNull(serviceConfig, "serviceConfig must not be null"); Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); @@ -60,6 +67,15 @@ public class SockJsWebSocketHandler extends TextWebSocketHandler { this.sockJsServiceConfig = serviceConfig; this.sockJsSession = sockJsSession; + + webSocketHandler = WebSocketHandlerDecorator.unwrap(webSocketHandler); + this.subProtocols = ((webSocketHandler instanceof SubProtocolCapable) ? + new ArrayList(((SubProtocolCapable) webSocketHandler).getSubProtocols()) : null); + } + + @Override + public List getSubProtocols() { + return this.subProtocols; } protected SockJsServiceConfig getSockJsConfig() { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandlerTests.java new file mode 100644 index 00000000000..ee2e521a66a --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandlerTests.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.transport.handler; + +import org.junit.Test; +import org.springframework.messaging.SubscribableChannel; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.handler.TextWebSocketHandler; +import org.springframework.web.socket.messaging.StompSubProtocolHandler; +import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; +import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link SockJsWebSocketHandler}. + * @author Rossen Stoyanchev + */ +public class SockJsWebSocketHandlerTests { + + + @Test + public void getSubProtocols() throws Exception { + SubscribableChannel channel = mock(SubscribableChannel.class); + SubProtocolWebSocketHandler handler = new SubProtocolWebSocketHandler(channel, channel); + StompSubProtocolHandler stompHandler = new StompSubProtocolHandler(); + handler.addProtocolHandler(stompHandler); + + TaskScheduler scheduler = mock(TaskScheduler.class); + DefaultSockJsService service = new DefaultSockJsService(scheduler); + WebSocketServerSockJsSession session = new WebSocketServerSockJsSession("1", service, handler, null); + SockJsWebSocketHandler sockJsHandler = new SockJsWebSocketHandler(service, handler, session); + + assertEquals(stompHandler.getSupportedProtocols(), sockJsHandler.getSubProtocols()); + } + + @Test + public void getSubProtocolsNone() throws Exception { + WebSocketHandler handler = new TextWebSocketHandler(); + TaskScheduler scheduler = mock(TaskScheduler.class); + DefaultSockJsService service = new DefaultSockJsService(scheduler); + WebSocketServerSockJsSession session = new WebSocketServerSockJsSession("1", service, handler, null); + SockJsWebSocketHandler sockJsHandler = new SockJsWebSocketHandler(service, handler, session); + + assertNull(sockJsHandler.getSubProtocols()); + } + +}