diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index bb3bc9e5ac5..1b94724b374 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -26,7 +26,6 @@ import java.util.HashSet; import java.util.List; import java.util.Random; import java.util.concurrent.TimeUnit; - import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; @@ -279,16 +278,13 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig * Configure allowed {@code Origin} header values. This check is mostly * designed for browsers. There is nothing preventing other types of client * to modify the {@code Origin} header value. - * *

When SockJS is enabled and origins are restricted, transport types * that do not allow to check request origin (JSONP and Iframe based * transports) are disabled. As a consequence, IE 6 to 9 are not supported * when origins are restricted. - * *

Each provided allowed origin must have a scheme, and optionally a port * (e.g. "http://example.org", "http://example.org:9090"). An allowed origin * string may also be "*" in which case all origins are allowed. - * * @since 4.1.2 * @see RFC 6454: The Web Origin Concept * @see SockJS supported transports by browser @@ -325,6 +321,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig return this.suppressCors; } + /** * This method determines the SockJS path and handles SockJS static URLs. * Session URLs and raw WebSocket requests are delegated to abstract methods. @@ -348,22 +345,26 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig // As per SockJS protocol content-type can be ignored (it's always json) } - String requestInfo = logger.isDebugEnabled() ? request.getMethod() + " " + request.getURI() : ""; + String requestInfo = (logger.isDebugEnabled() ? request.getMethod() + " " + request.getURI() : null); try { if (sockJsPath.equals("") || sockJsPath.equals("/")) { - logger.debug(requestInfo); + if (requestInfo != null) { + logger.debug("Processing transport request: " + requestInfo); + } response.getHeaders().setContentType(new MediaType("text", "plain", UTF8_CHARSET)); response.getBody().write("Welcome to SockJS!\n".getBytes(UTF8_CHARSET)); } else if (sockJsPath.equals("/info")) { - logger.debug(requestInfo); + if (requestInfo != null) { + logger.debug("Processing transport request: " + requestInfo); + } this.infoHandler.handle(request, response); } else if (sockJsPath.matches("/iframe[0-9-.a-z_]*.html")) { if (!this.allowedOrigins.isEmpty() && !this.allowedOrigins.contains("*")) { - if (logger.isDebugEnabled()) { - logger.debug("Iframe support is disabled when an origin check is required, ignoring " + - requestInfo); + if (requestInfo != null) { + logger.debug("Iframe support is disabled when an origin check is required. " + + "Ignoring transport request: " + requestInfo); } response.setStatusCode(HttpStatus.NOT_FOUND); return; @@ -371,45 +372,57 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig if (this.allowedOrigins.isEmpty()) { response.getHeaders().add(XFRAME_OPTIONS_HEADER, "SAMEORIGIN"); } - logger.debug(requestInfo); + if (requestInfo != null) { + logger.debug("Processing transport request: " + requestInfo); + } this.iframeHandler.handle(request, response); } else if (sockJsPath.equals("/websocket")) { if (isWebSocketEnabled()) { - logger.debug(requestInfo); + if (requestInfo != null) { + logger.debug("Processing transport request: " + requestInfo); + } handleRawWebSocketRequest(request, response, wsHandler); } - else if (logger.isDebugEnabled()) { - logger.debug("WebSocket disabled, ignoring " + requestInfo); + else if (requestInfo != null) { + logger.debug("WebSocket disabled. Ignoring transport request: " + requestInfo); } } else { String[] pathSegments = StringUtils.tokenizeToStringArray(sockJsPath.substring(1), "/"); if (pathSegments.length != 3) { if (logger.isWarnEnabled()) { - logger.warn("Ignoring invalid transport request " + requestInfo); + logger.warn("Invalid SockJS path '" + sockJsPath + "' - required to have 3 path segments"); + } + if (requestInfo != null) { + logger.debug("Ignoring transport request: " + requestInfo); } response.setStatusCode(HttpStatus.NOT_FOUND); return; } + String serverId = pathSegments[0]; String sessionId = pathSegments[1]; String transport = pathSegments[2]; if (!isWebSocketEnabled() && transport.equals("websocket")) { - if (logger.isDebugEnabled()) { - logger.debug("WebSocket transport is disabled, ignoring " + requestInfo); + if (requestInfo != null) { + logger.debug("WebSocket disabled. Ignoring transport request: " + requestInfo); } response.setStatusCode(HttpStatus.NOT_FOUND); return; } else if (!validateRequest(serverId, sessionId, transport)) { - if (logger.isWarnEnabled()) { - logger.warn("Ignoring transport request " + requestInfo); + if (requestInfo != null) { + logger.debug("Ignoring transport request: " + requestInfo); } response.setStatusCode(HttpStatus.NOT_FOUND); return; } + + if (requestInfo != null) { + logger.debug("Processing transport request: " + requestInfo); + } handleTransportRequest(request, response, wsHandler, sessionId, transport); } response.close(); @@ -421,14 +434,16 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig protected boolean validateRequest(String serverId, String sessionId, String transport) { if (!StringUtils.hasText(serverId) || !StringUtils.hasText(sessionId) || !StringUtils.hasText(transport)) { - logger.warn("No server, session, or transport path segment"); + logger.warn("No server, session, or transport path segment in SockJS request."); return false; } + // Server and session id's must not contain "." if (serverId.contains(".") || sessionId.contains(".")) { logger.warn("Either server or session contains a \".\" which is not allowed by SockJS protocol."); return false; } + return true; } @@ -445,6 +460,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException; + protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) throws IOException { @@ -454,7 +470,9 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) { String origin = request.getHeaders().getOrigin(); - logger.debug("Request rejected, Origin header value " + origin + " not allowed"); + if (logger.isWarnEnabled()) { + logger.warn("Origin header value '" + origin + "' not allowed."); + } response.setStatusCode(HttpStatus.FORBIDDEN); return false; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/SockJsHttpRequestHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/SockJsHttpRequestHandler.java index d2ba85ec169..81e0c65a065 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/SockJsHttpRequestHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/SockJsHttpRequestHandler.java @@ -17,7 +17,6 @@ package org.springframework.web.socket.sockjs.support; import java.io.IOException; - import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -66,8 +65,8 @@ public class SockJsHttpRequestHandler * @param webSocketHandler the websocket handler */ public SockJsHttpRequestHandler(SockJsService sockJsService, WebSocketHandler webSocketHandler) { - Assert.notNull(sockJsService, "sockJsService must not be null"); - Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); + Assert.notNull(sockJsService, "SockJsService must not be null"); + Assert.notNull(webSocketHandler, "WebSocketHandler must not be null"); this.sockJsService = sockJsService; this.webSocketHandler = new ExceptionWebSocketHandlerDecorator(new LoggingWebSocketHandlerDecorator(webSocketHandler)); @@ -95,10 +94,6 @@ public class SockJsHttpRequestHandler } } - @Override - public boolean isRunning() { - return this.running; - } @Override public void start() { @@ -120,6 +115,11 @@ public class SockJsHttpRequestHandler } } + @Override + public boolean isRunning() { + return this.running; + } + @Override public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse servletResponse) @@ -139,13 +139,13 @@ public class SockJsHttpRequestHandler private String getSockJsPath(HttpServletRequest servletRequest) { String attribute = HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE; String path = (String) servletRequest.getAttribute(attribute); - return ((path.length() > 0) && (path.charAt(0) != '/')) ? "/" + path : path; + return (path.length() > 0 && path.charAt(0) != '/' ? "/" + path : path); } @Override public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { - if (sockJsService instanceof CorsConfigurationSource) { - return ((CorsConfigurationSource)sockJsService).getCorsConfiguration(request); + if (this.sockJsService instanceof CorsConfigurationSource) { + return ((CorsConfigurationSource) this.sockJsService).getCorsConfiguration(request); } return null; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index 9c27ea4e33b..6854bf5db4e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -60,8 +60,7 @@ import org.springframework.web.socket.sockjs.support.AbstractSockJsService; * @author Sebastien Deleuze * @since 4.0 */ -public class TransportHandlingSockJsService extends AbstractSockJsService - implements SockJsServiceConfig, Lifecycle { +public class TransportHandlingSockJsService extends AbstractSockJsService implements SockJsServiceConfig, Lifecycle { private static final boolean jackson2Present = ClassUtils.isPresent( "com.fasterxml.jackson.databind.ObjectMapper", TransportHandlingSockJsService.class.getClassLoader()); @@ -154,10 +153,6 @@ public class TransportHandlingSockJsService extends AbstractSockJsService return this.interceptors; } - @Override - public boolean isRunning() { - return this.running; - } @Override public void start() { @@ -183,6 +178,11 @@ public class TransportHandlingSockJsService extends AbstractSockJsService } } + @Override + public boolean isRunning() { + return this.running; + } + @Override protected void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, @@ -322,13 +322,21 @@ public class TransportHandlingSockJsService extends AbstractSockJsService @Override protected boolean validateRequest(String serverId, String sessionId, String transport) { - if (!getAllowedOrigins().contains("*") && !TransportType.fromValue(transport).supportsOrigin()) { - if (logger.isWarnEnabled()) { - logger.warn("Origin check has been enabled, but transport " + transport + " does not support it"); - } + if (!super.validateRequest(serverId, sessionId, transport)) { return false; } - return super.validateRequest(serverId, sessionId, transport); + + if (!getAllowedOrigins().contains("*")) { + TransportType transportType = TransportType.fromValue(transport); + if (transportType == null || !transportType.supportsOrigin()) { + if (logger.isWarnEnabled()) { + logger.warn("Origin check enabled but transport '" + transport + "' does not support it."); + } + return false; + } + } + + return true; } private SockJsSession createSockJsSession(String sessionId, SockJsSessionFactory sessionFactory, diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java index e3da359f915..78b844240e2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java @@ -50,13 +50,8 @@ public enum TransportType { HTML_FILE("htmlfile", HttpMethod.GET, "cors", "jsessionid", "no_cache"); - private final String value; - - private final HttpMethod httpMethod; - - private final List headerHints; - private static final Map TRANSPORT_TYPES; + static { Map transportTypes = new HashMap(); for (TransportType type : values()) { @@ -65,8 +60,19 @@ public enum TransportType { TRANSPORT_TYPES = Collections.unmodifiableMap(transportTypes); } + public static TransportType fromValue(String value) { + return TRANSPORT_TYPES.get(value); + } + + + private final String value; + + private final HttpMethod httpMethod; - private TransportType(String value, HttpMethod httpMethod, String... headerHints) { + private final List headerHints; + + + TransportType(String value, HttpMethod httpMethod, String... headerHints) { this.value = value; this.httpMethod = httpMethod; this.headerHints = Arrays.asList(headerHints); @@ -77,9 +83,6 @@ public enum TransportType { return this.value; } - /** - * The HTTP method for this transport. - */ public HttpMethod getHttpMethod() { return this.httpMethod; } @@ -88,6 +91,10 @@ public enum TransportType { return this.headerHints.contains("no_cache"); } + public boolean sendsSessionCookie() { + return this.headerHints.contains("jsessionid"); + } + public boolean supportsCors() { return this.headerHints.contains("cors"); } @@ -96,13 +103,6 @@ public enum TransportType { return this.headerHints.contains("cors") || this.headerHints.contains("origin"); } - public boolean sendsSessionCookie() { - return this.headerHints.contains("jsessionid"); - } - - public static TransportType fromValue(String value) { - return TRANSPORT_TYPES.get(value); - } @Override public String toString() { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java index 6b07be82411..aea9b0e802d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java @@ -16,9 +16,6 @@ package org.springframework.web.socket.sockjs.transport.handler; -import static org.junit.Assert.*; -import static org.mockito.BDDMockito.*; - import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -43,11 +40,15 @@ import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig; import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession; +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + /** * Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}. * * @author Rossen Stoyanchev * @author Sebastien Deleuze + * @author Ben Kiefer */ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { @@ -186,6 +187,18 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { assertEquals(200, this.servletResponse.getStatus()); } + @Test // SPR-13545 + public void handleInvalidTransportType() throws Exception { + String sockJsPath = sessionUrlPrefix + "invalid"; + setRequest("POST", sockJsPrefix + sockJsPath); + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com"); + this.servletRequest.setServerName("mydomain2.com"); + this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + + assertEquals(404, this.servletResponse.getStatus()); + } + @Test public void handleTransportRequestXhrOptions() throws Exception { String sockJsPath = sessionUrlPrefix + "xhr"; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java index c85a6cf4b9b..070f8d7816f 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.web.socket.sockjs.transport.handler; -import org.junit.Before; import org.junit.Test; import org.springframework.http.MediaType; @@ -37,14 +36,7 @@ import static org.mockito.BDDMockito.*; * * @author Rossen Stoyanchev */ -public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTests { - - - @Override - @Before - public void setUp() { - super.setUp(); - } +public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTests { @Test public void readMessagesXhr() throws Exception { @@ -73,9 +65,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest assertEquals("ok", this.servletResponse.getContentAsString()); } - // SPR-10621 - - @Test + @Test // SPR-10621 public void readMessagesJsonpFormEncodedWithEncoding() throws Exception { this.servletRequest.setContent("d=[\"x\"]".getBytes("UTF-8")); this.servletRequest.setContentType("application/x-www-form-urlencoded;charset=UTF-8"); @@ -102,9 +92,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest @Test public void delegateMessageException() throws Exception { - StubSockJsServiceConfig sockJsConfig = new StubSockJsServiceConfig(); - this.servletRequest.setContent("[\"x\"]".getBytes("UTF-8")); WebSocketHandler wsHandler = mock(WebSocketHandler.class); @@ -126,7 +114,6 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest private void handleRequest(AbstractHttpReceivingTransportHandler transportHandler) throws Exception { - WebSocketHandler wsHandler = mock(WebSocketHandler.class); AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler, null); @@ -138,7 +125,6 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest } private void handleRequestAndExpectFailure() throws Exception { - resetResponse(); WebSocketHandler wsHandler = mock(WebSocketHandler.class); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java index 7747fa815ee..bc780688403 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,9 +62,9 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests setRequest("POST", "/"); } + @Test public void handleRequestXhr() throws Exception { - XhrPollingTransportHandler transportHandler = new XhrPollingTransportHandler(); transportHandler.initialize(this.sockJsConfig); @@ -91,7 +91,6 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests @Test public void jsonpTransport() throws Exception { - JsonpPollingTransportHandler transportHandler = new JsonpPollingTransportHandler(); transportHandler.initialize(this.sockJsConfig); PollingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); @@ -114,7 +113,6 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests @Test public void handleRequestXhrStreaming() throws Exception { - XhrStreamingTransportHandler transportHandler = new XhrStreamingTransportHandler(); transportHandler.initialize(this.sockJsConfig); AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); @@ -128,7 +126,6 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests @Test public void htmlFileTransport() throws Exception { - HtmlFileTransportHandler transportHandler = new HtmlFileTransportHandler(); transportHandler.initialize(this.sockJsConfig); StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); @@ -151,7 +148,6 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests @Test public void eventSourceTransport() throws Exception { - EventSourceTransportHandler transportHandler = new EventSourceTransportHandler(); transportHandler.initialize(this.sockJsConfig); StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); @@ -165,7 +161,6 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests @Test public void frameFormats() throws Exception { - this.servletRequest.setQueryString("c=callback"); this.servletRequest.addParameter("c", "callback"); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandlerTests.java index 8bc7bc679d0..1d8d0efe8de 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,11 +31,11 @@ import static org.mockito.Mockito.*; /** * Unit tests for {@link SockJsWebSocketHandler}. + * * @author Rossen Stoyanchev */ public class SockJsWebSocketHandlerTests { - @Test public void getSubProtocols() throws Exception { SubscribableChannel channel = mock(SubscribableChannel.class);