diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java index b58c3570641..00d2b5efb8c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java @@ -16,11 +16,11 @@ package org.springframework.web.socket.server.support; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.List; +import java.util.LinkedHashSet; import java.util.Map; +import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -34,8 +34,8 @@ import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.util.WebUtils; /** - * An interceptor to check request {@code Origin} header value against a collection of - * allowed origins. + * An interceptor to check request {@code Origin} header value against a + * collection of allowed origins. * * @author Sebastien Deleuze * @since 4.1.2 @@ -44,58 +44,57 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { protected Log logger = LogFactory.getLog(getClass()); - private final List allowedOrigins; + private final Set allowedOrigins = new LinkedHashSet(); /** * Default constructor with only same origin requests allowed. */ public OriginHandshakeInterceptor() { - this.allowedOrigins = new ArrayList(); } /** * Constructor using the specified allowed origin values. - * * @see #setAllowedOrigins(Collection) */ public OriginHandshakeInterceptor(Collection allowedOrigins) { - this(); setAllowedOrigins(allowedOrigins); } + /** - * Configure allowed {@code Origin} header values. This check is mostly designed for - * browser clients. There is nothing preventing other types of client to modify the - * {@code Origin} header value. - * - *

Each provided allowed origin must start by "http://", "https://" or be "*" - * (means that all origins are allowed). - * + * Configure allowed {@code Origin} header values. This check is mostly + * designed for browsers. There is nothing preventing other types of client + * to modify the {@code Origin} header value. + *

Each provided allowed origin must have a scheme, and optionally a port + * (e.g. "http://example.org", "http://example.org:9090"). An allowed origin + * string may also be "*" in which case all origins are allowed. * @see RFC 6454: The Web Origin Concept */ public void setAllowedOrigins(Collection allowedOrigins) { - Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null"); + Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null"); this.allowedOrigins.clear(); this.allowedOrigins.addAll(allowedOrigins); } /** - * @see #setAllowedOrigins(Collection) * @since 4.1.5 + * @see #setAllowedOrigins */ public Collection getAllowedOrigins() { - return Collections.unmodifiableList(this.allowedOrigins); + return Collections.unmodifiableSet(this.allowedOrigins); } + @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception { + if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) { response.setStatusCode(HttpStatus.FORBIDDEN); if (logger.isDebugEnabled()) { - logger.debug("Handshake request rejected, Origin header value " - + request.getHeaders().getOrigin() + " not allowed"); + logger.debug("Handshake request rejected, Origin header value " + + request.getHeaders().getOrigin() + " not allowed"); } return false; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index 9559bc3c65e..f3562f328da 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -18,13 +18,15 @@ package org.springframework.web.socket.sockjs.support; import java.io.IOException; import java.nio.charset.Charset; -import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Date; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Random; +import java.util.Set; import java.util.concurrent.TimeUnit; import org.apache.commons.logging.Log; @@ -53,7 +55,7 @@ import org.springframework.web.util.WebUtils; * path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html", * etc). Sub-classes must handle session URLs (i.e. transport-specific requests). * - * By default, only same origin requests are allowed. Use {@link #setAllowedOrigins(List)} + * By default, only same origin requests are allowed. Use {@link #setAllowedOrigins} * to specify a list of allowed origins (a list containing "*" will allow all origins). * * @author Rossen Stoyanchev @@ -91,10 +93,10 @@ public abstract class AbstractSockJsService implements SockJsService { private boolean webSocketEnabled = true; - private final List allowedOrigins = new ArrayList(); - private boolean suppressCors = false; + protected final Set allowedOrigins = new LinkedHashSet(); + public AbstractSockJsService(TaskScheduler scheduler) { Assert.notNull(scheduler, "TaskScheduler must not be null"); @@ -271,6 +273,24 @@ public abstract class AbstractSockJsService implements SockJsService { return this.webSocketEnabled; } + /** + * This option can be used to disable automatic addition of CORS headers for + * SockJS requests. + *

The default value is "false". + * @since 4.1.2 + */ + public void setSuppressCors(boolean suppressCors) { + this.suppressCors = suppressCors; + } + + /** + * @since 4.1.2 + * @see #setSuppressCors(boolean) + */ + public boolean shouldSuppressCors() { + return this.suppressCors; + } + /** * Configure allowed {@code Origin} header values. This check is mostly * designed for browsers. There is nothing preventing other types of client @@ -286,36 +306,18 @@ public abstract class AbstractSockJsService implements SockJsService { * @see RFC 6454: The Web Origin Concept * @see SockJS supported transports by browser */ - public void setAllowedOrigins(List allowedOrigins) { - Assert.notNull(allowedOrigins, "Allowed origin List must not be null"); + public void setAllowedOrigins(Collection allowedOrigins) { + Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null"); this.allowedOrigins.clear(); this.allowedOrigins.addAll(allowedOrigins); } /** * @since 4.1.2 - * @see #setAllowedOrigins(List) - */ - public List getAllowedOrigins() { - return Collections.unmodifiableList(this.allowedOrigins); - } - - /** - * This option can be used to disable automatic addition of CORS headers for - * SockJS requests. - *

The default value is "false". - * @since 4.1.2 + * @see #setAllowedOrigins */ - public void setSuppressCors(boolean suppressCors) { - this.suppressCors = suppressCors; - } - - /** - * @since 4.1.2 - * @see #setSuppressCors(boolean) - */ - public boolean shouldSuppressCors() { - return this.suppressCors; + public Collection getAllowedOrigins() { + return Collections.unmodifiableSet(this.allowedOrigins); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index 058b7e60581..c5df1351bf0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -292,7 +292,7 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem return false; } - if (!getAllowedOrigins().contains("*")) { + if (!this.allowedOrigins.contains("*")) { TransportType transportType = TransportType.fromValue(transport); if (transportType == null || !transportType.supportsOrigin()) { if (logger.isWarnEnabled()) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java index 32fa2f9ff03..a90c30dde0e 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java @@ -18,14 +18,11 @@ package org.springframework.web.socket.config; import java.io.IOException; import java.io.InputStream; -import java.util.Arrays; import java.util.Date; import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledFuture; -import static org.junit.Assert.assertEquals; -import org.junit.Before; import org.junit.Test; import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; @@ -77,13 +74,7 @@ import static org.junit.Assert.*; */ public class HandlersBeanDefinitionParserTests { - private GenericWebApplicationContext appContext; - - - @Before - public void setup() { - this.appContext = new GenericWebApplicationContext(); - } + private final GenericWebApplicationContext appContext = new GenericWebApplicationContext(); @Test @@ -235,10 +226,12 @@ public class HandlersBeanDefinitionParserTests { List interceptors = transportService.getHandshakeInterceptors(); assertThat(interceptors, contains(instanceOf(OriginHandshakeInterceptor.class))); - assertEquals(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"), transportService.getAllowedOrigins()); assertTrue(transportService.shouldSuppressCors()); + assertTrue(transportService.getAllowedOrigins().contains("http://mydomain1.com")); + assertTrue(transportService.getAllowedOrigins().contains("http://mydomain2.com")); } + private void loadBeanDefinitions(String fileName) { XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.appContext); ClassPathResource resource = new ClassPathResource(fileName, HandlersBeanDefinitionParserTests.class); @@ -279,9 +272,11 @@ class TestWebSocketHandler implements WebSocketHandler { } } + class FooWebSocketHandler extends TestWebSocketHandler { } + class TestHandshakeHandler implements HandshakeHandler { @Override @@ -292,9 +287,11 @@ class TestHandshakeHandler implements HandshakeHandler { } } + class TestChannelInterceptor extends ChannelInterceptorAdapter { } + class FooTestInterceptor implements HandshakeInterceptor { @Override @@ -310,9 +307,11 @@ class FooTestInterceptor implements HandshakeInterceptor { } } + class BarTestInterceptor extends FooTestInterceptor { } + @SuppressWarnings({ "unchecked", "rawtypes" }) class TestTaskScheduler implements TaskScheduler { @@ -345,9 +344,9 @@ class TestTaskScheduler implements TaskScheduler { public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) { return null; } - } + class TestMessageCodec implements SockJsMessageCodec { @Override @@ -364,4 +363,4 @@ class TestMessageCodec implements SockJsMessageCodec { public String[] decodeInputStream(InputStream content) throws IOException { return new String[0]; } -} \ No newline at end of file +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 4132339a3e9..60cce380083 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ import java.util.Arrays; import java.util.List; import org.hamcrest.Matchers; -import org.junit.Before; import org.junit.Test; import org.springframework.beans.DirectFieldAccessor; @@ -77,15 +76,8 @@ import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.instanceOf; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; /** * Test fixture for MessageBrokerBeanDefinitionParser. @@ -97,16 +89,11 @@ import static org.junit.Assert.fail; */ public class MessageBrokerBeanDefinitionParserTests { - private GenericWebApplicationContext appContext; - - - @Before - public void setup() { - this.appContext = new GenericWebApplicationContext(); - } + private final GenericWebApplicationContext appContext = new GenericWebApplicationContext(); @Test + @SuppressWarnings("unchecked") public void simpleBroker() throws Exception { loadBeanDefinitions("websocket-config-broker-simple.xml"); @@ -178,7 +165,8 @@ public class MessageBrokerBeanDefinitionParserTests { interceptors = defaultSockJsService.getHandshakeInterceptors(); assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class))); - assertEquals(Arrays.asList("http://mydomain3.com", "http://mydomain4.com"), defaultSockJsService.getAllowedOrigins()); + assertTrue(defaultSockJsService.getAllowedOrigins().contains("http://mydomain3.com")); + assertTrue(defaultSockJsService.getAllowedOrigins().contains("http://mydomain4.com")); UserSessionRegistry userSessionRegistry = this.appContext.getBean(UserSessionRegistry.class); assertNotNull(userSessionRegistry); @@ -221,7 +209,7 @@ public class MessageBrokerBeanDefinitionParserTests { assertNotNull(this.appContext.getBean("webSocketScopeConfigurer", CustomScopeConfigurer.class)); DirectFieldAccessor subscriptionRegistryAccessor = new DirectFieldAccessor(brokerMessageHandler.getSubscriptionRegistry()); - String pathSeparator = (String)new DirectFieldAccessor(subscriptionRegistryAccessor.getPropertyValue("pathMatcher")).getPropertyValue("pathSeparator"); + String pathSeparator = (String) new DirectFieldAccessor(subscriptionRegistryAccessor.getPropertyValue("pathMatcher")).getPropertyValue("pathSeparator"); assertEquals(".", pathSeparator); } @@ -330,7 +318,7 @@ public class MessageBrokerBeanDefinitionParserTests { assertEquals(MimeTypeUtils.APPLICATION_JSON, ((DefaultContentTypeResolver) resolver).getDefaultMimeType()); DirectFieldAccessor handlerAccessor = new DirectFieldAccessor(annotationMethodMessageHandler); - String pathSeparator = (String)new DirectFieldAccessor(handlerAccessor.getPropertyValue("pathMatcher")).getPropertyValue("pathSeparator"); + String pathSeparator = (String) new DirectFieldAccessor(handlerAccessor.getPropertyValue("pathMatcher")).getPropertyValue("pathSeparator"); assertEquals(".", pathSeparator); } @@ -445,9 +433,9 @@ public class MessageBrokerBeanDefinitionParserTests { return (handler instanceof WebSocketHandlerDecorator) ? ((WebSocketHandlerDecorator) handler).getLastHandler() : handler; } - } + class CustomArgumentResolver implements HandlerMethodArgumentResolver { @Override @@ -461,6 +449,7 @@ class CustomArgumentResolver implements HandlerMethodArgumentResolver { } } + class CustomReturnValueHandler implements HandlerMethodReturnValueHandler { @Override @@ -474,6 +463,7 @@ class CustomReturnValueHandler implements HandlerMethodReturnValueHandler { } } + class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorFactory { @Override @@ -482,6 +472,7 @@ class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorF } } + class TestWebSocketHandlerDecorator extends WebSocketHandlerDecorator { public TestWebSocketHandlerDecorator(WebSocketHandler delegate) { @@ -493,4 +484,4 @@ class TestWebSocketHandlerDecorator extends WebSocketHandlerDecorator { session.getAttributes().put("decorated", true); super.afterConnectionEstablished(session); } -} \ No newline at end of file +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java index d323e2b8435..435e5158508 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,9 +29,9 @@ import org.springframework.scheduling.TaskScheduler; import org.springframework.util.MultiValueMap; import org.springframework.web.HttpRequestHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; -import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; import org.springframework.web.socket.sockjs.transport.TransportHandler; @@ -117,7 +117,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); assertNotNull(requestHandler.getSockJsService()); DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); - assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); + assertTrue(sockJsService.getAllowedOrigins().contains(origin)); assertFalse(sockJsService.shouldSuppressCors()); registration = @@ -128,7 +128,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); assertNotNull(requestHandler.getSockJsService()); sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); - assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); + assertTrue(sockJsService.getAllowedOrigins().contains(origin)); assertFalse(sockJsService.shouldSuppressCors()); } @@ -255,7 +255,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0)); assertEquals(OriginHandshakeInterceptor.class, sockJsService.getHandshakeInterceptors().get(1).getClass()); - assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); + assertTrue(sockJsService.getAllowedOrigins().contains(origin)); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java index 7591558e108..3a18fb7d13f 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java @@ -17,7 +17,6 @@ package org.springframework.web.socket.config.annotation; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import org.junit.Before; @@ -29,9 +28,9 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.handler.TextWebSocketHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeInterceptor; -import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; @@ -148,8 +147,7 @@ public class WebSocketHandlerRegistrationTests { assertEquals(handler, mapping.webSocketHandler); assertEquals("/foo/**", mapping.path); assertNotNull(mapping.sockJsService); - assertEquals(Arrays.asList("http://mydomain1.com"), - mapping.sockJsService.getAllowedOrigins()); + assertTrue(mapping.sockJsService.getAllowedOrigins().contains("http://mydomain1.com")); List interceptors = mapping.sockJsService.getHandshakeInterceptors(); assertEquals(interceptor, interceptors.get(0)); assertEquals(OriginHandshakeInterceptor.class, interceptors.get(1).getClass()); @@ -218,6 +216,7 @@ public class WebSocketHandlerRegistrationTests { } } + private static class Mapping { private final WebSocketHandler webSocketHandler; @@ -230,7 +229,6 @@ public class WebSocketHandlerRegistrationTests { private final DefaultSockJsService sockJsService; - public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) { this.webSocketHandler = handler; this.path = path;