diff --git a/build.gradle b/build.gradle index 8d14df6d261..f64da0967c7 100644 --- a/build.gradle +++ b/build.gradle @@ -558,15 +558,17 @@ project("spring-web") { optional("org.codehaus.jackson:jackson-mapper-asl:1.9.12") optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") optional("taglibs:standard:1.1.2") - optional("org.eclipse.jetty:jetty-servlet:8.1.5.v20120716") { + optional("org.eclipse.jetty:jetty-servlet:9.0.5.v20130815") { exclude group: "org.eclipse.jetty.orbit", module: "javax.servlet" } - optional("org.eclipse.jetty:jetty-server:8.1.5.v20120716") { + optional("org.eclipse.jetty:jetty-server:9.0.5.v20130815") { exclude group: "org.eclipse.jetty.orbit", module: "javax.servlet" } optional("log4j:log4j:1.2.17") testCompile(project(":spring-context-support")) // for JafMediaTypeFactory testCompile("xmlunit:xmlunit:1.3") + testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") + testCompile("log4j:log4j:1.2.17") } // pick up ContextLoader.properties in src/main diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java index 1d3dcb30a04..fa780ed94df 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java @@ -69,7 +69,7 @@ public interface SubProtocolHandler { /** * Resolve the session id from the given message or return {@code null}. * - * @param the message to resolve the session id from + * @param message the message to resolve the session id from */ String resolveSessionId(Message message); diff --git a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java index 136ff118686..f5391599d85 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java +++ b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java @@ -92,16 +92,6 @@ public class HttpHeaders implements MultiValueMap, Serializable private static final String ORIGIN = "Origin"; - private static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; - - private static final String SEC_WEBSOCKET_EXTENSIONS = "Sec-WebSocket-Extensions"; - - private static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; - - private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; - - private static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version"; - private static final String PRAGMA = "Pragma"; private static final String UPGARDE = "Upgrade"; @@ -452,7 +442,7 @@ public class HttpHeaders implements MultiValueMap, Serializable set(IF_NONE_MATCH, toCommaDelimitedString(ifNoneMatchList)); } - private String toCommaDelimitedString(List list) { + protected String toCommaDelimitedString(List list) { StringBuilder builder = new StringBuilder(); for (Iterator iterator = list.iterator(); iterator.hasNext();) { String ifNoneMatch = iterator.next(); @@ -472,7 +462,7 @@ public class HttpHeaders implements MultiValueMap, Serializable return getFirstValueAsList(IF_NONE_MATCH); } - private List getFirstValueAsList(String header) { + protected List getFirstValueAsList(String header) { List result = new ArrayList(); String value = getFirst(header); @@ -537,114 +527,6 @@ public class HttpHeaders implements MultiValueMap, Serializable return getFirst(ORIGIN); } - /** - * Sets the (new) value of the {@code Sec-WebSocket-Accept} header. - * @param secWebSocketAccept the value of the header - */ - public void setSecWebSocketAccept(String secWebSocketAccept) { - set(SEC_WEBSOCKET_ACCEPT, secWebSocketAccept); - } - - /** - * Returns the value of the {@code Sec-WebSocket-Accept} header. - * @return the value of the header - */ - public String getSecWebSocketAccept() { - return getFirst(SEC_WEBSOCKET_ACCEPT); - } - - /** - * Returns the value of the {@code Sec-WebSocket-Extensions} header. - * @return the value of the header - */ - public List getSecWebSocketExtensions() { - List values = get(SEC_WEBSOCKET_EXTENSIONS); - if (CollectionUtils.isEmpty(values)) { - return Collections.emptyList(); - } - else if (values.size() == 1) { - return getFirstValueAsList(SEC_WEBSOCKET_EXTENSIONS); - } - else { - return values; - } - } - - /** - * Sets the (new) value of the {@code Sec-WebSocket-Extensions} header. - * @param secWebSocketExtensions the value of the header - */ - public void setSecWebSocketExtensions(List secWebSocketExtensions) { - set(SEC_WEBSOCKET_EXTENSIONS, toCommaDelimitedString(secWebSocketExtensions)); - } - - /** - * Sets the (new) value of the {@code Sec-WebSocket-Key} header. - * @param secWebSocketKey the value of the header - */ - public void setSecWebSocketKey(String secWebSocketKey) { - set(SEC_WEBSOCKET_KEY, secWebSocketKey); - } - - /** - * Returns the value of the {@code Sec-WebSocket-Key} header. - * @return the value of the header - */ - public String getSecWebSocketKey() { - return getFirst(SEC_WEBSOCKET_KEY); - } - - /** - * Sets the (new) value of the {@code Sec-WebSocket-Protocol} header. - * @param secWebSocketProtocol the value of the header - */ - public void setSecWebSocketProtocol(String secWebSocketProtocol) { - if (secWebSocketProtocol != null) { - set(SEC_WEBSOCKET_PROTOCOL, secWebSocketProtocol); - } - } - - /** - * Sets the (new) value of the {@code Sec-WebSocket-Protocol} header. - * @param secWebSocketProtocols the value of the header - */ - public void setSecWebSocketProtocol(List secWebSocketProtocols) { - set(SEC_WEBSOCKET_PROTOCOL, toCommaDelimitedString(secWebSocketProtocols)); - } - - /** - * Returns the value of the {@code Sec-WebSocket-Key} header. - * @return the value of the header - */ - public List getSecWebSocketProtocol() { - List values = get(SEC_WEBSOCKET_PROTOCOL); - if (CollectionUtils.isEmpty(values)) { - return Collections.emptyList(); - } - else if (values.size() == 1) { - return getFirstValueAsList(SEC_WEBSOCKET_PROTOCOL); - } - else { - return values; - } - } - - /** - * Sets the (new) value of the {@code Sec-WebSocket-Version} header. - * @param secWebSocketKey the value of the header - */ - public void setSecWebSocketVersion(String secWebSocketVersion) { - set(SEC_WEBSOCKET_VERSION, secWebSocketVersion); - } - - /** - * Returns the value of the {@code Sec-WebSocket-Version} header. - * @return the value of the header - */ - public String getSecWebSocketVersion() { - return getFirst(SEC_WEBSOCKET_VERSION); - } - /** * Sets the (new) value of the {@code Pragma} header. * @param pragma the value of the header diff --git a/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java index 66e91dda57b..e859bfca3fd 100644 --- a/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java @@ -19,7 +19,9 @@ package org.springframework.web.bind.support; import java.io.IOException; import java.util.List; +import javax.servlet.MultipartConfigElement; import javax.servlet.ServletException; +import javax.servlet.annotation.MultipartConfig; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -79,8 +81,15 @@ public class WebRequestDataBinderIntegrationTests { partsServlet = new PartsServlet(); partListServlet = new PartListServlet(); - handler.addServlet(new ServletHolder(partsServlet), "/parts"); - handler.addServlet(new ServletHolder(partListServlet), "/partlist"); + MultipartConfigElement multipartConfig = new MultipartConfigElement(""); + + ServletHolder holder = new ServletHolder(partsServlet); + holder.getRegistration().setMultipartConfig(multipartConfig); + handler.addServlet(holder, "/parts"); + + holder = new ServletHolder(partListServlet); + holder.getRegistration().setMultipartConfig(multipartConfig); + handler.addServlet(holder, "/partlist"); jettyServer.setHandler(handler); jettyServer.start(); } @@ -178,7 +187,6 @@ public class WebRequestDataBinderIntegrationTests { @SuppressWarnings("serial") private static class PartsServlet extends AbstractStandardMultipartServlet { - } private static class PartListBean { @@ -197,7 +205,6 @@ public class WebRequestDataBinderIntegrationTests { @SuppressWarnings("serial") private static class PartListServlet extends AbstractStandardMultipartServlet { - } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java index b935e21889d..7d5d1a5e4de 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java @@ -20,9 +20,11 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import org.springframework.http.HttpHeaders; +import org.springframework.web.socket.support.WebSocketExtension; /** * A WebSocket session abstraction. Allows sending messages over a WebSocket connection @@ -79,6 +81,12 @@ public interface WebSocketSession { */ String getAcceptedProtocol(); + /** + * Return the negotiated extensions or {@code null} if none was specified or + * negotiated successfully. + */ + List getExtensions(); + /** * Return whether the connection is still open. */ diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java index b7bc0a8130a..5e507f6d065 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java @@ -20,8 +20,11 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; import org.springframework.http.HttpHeaders; import org.springframework.util.ObjectUtils; import org.springframework.web.socket.BinaryMessage; @@ -29,6 +32,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketSession; /** @@ -42,6 +46,8 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion extensions; + private final Principal principal; @@ -104,6 +110,19 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion getExtensions() { + checkNativeSessionInitialized(); + if(this.extensions == null) { + List source = getNativeSession().getUpgradeResponse().getExtensions(); + this.extensions = new ArrayList(source.size()); + for(ExtensionConfig e : source) { + this.extensions.add(new WebSocketExtension(e.getName(), e.getParameters())); + } + } + return this.extensions; + } + @Override public boolean isOpen() { return ((getNativeSession() != null) && getNativeSession().isOpen()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java index 7ddd3a87187..d01d33b9fd0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java @@ -20,10 +20,14 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Extension; import org.springframework.http.HttpHeaders; import org.springframework.util.StringUtils; @@ -32,7 +36,9 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.support.WebSocketHttpHeaders; /** * A {@link WebSocketSession} for use with the standard WebSocket for Java API. @@ -48,22 +54,24 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion extensions; + /** * Class constructor. * - * @param handshakeHeaders the headers of the handshake request + * @param headers the headers of the handshake request * @param handshakeAttributes attributes from the HTTP handshake to make available * through the WebSocket session * @param localAddress the address on which the request was received * @param remoteAddress the address of the remote client */ - public StandardWebSocketSession(HttpHeaders handshakeHeaders, Map handshakeAttributes, + public StandardWebSocketSession(HttpHeaders headers, Map handshakeAttributes, InetSocketAddress localAddress, InetSocketAddress remoteAddress) { super(handshakeAttributes); - handshakeHeaders = (handshakeHeaders != null) ? handshakeHeaders : new HttpHeaders(); - this.handshakeHeaders = HttpHeaders.readOnlyHttpHeaders(handshakeHeaders); + headers = (headers != null) ? headers : new HttpHeaders(); + this.handshakeHeaders = HttpHeaders.readOnlyHttpHeaders(headers); this.localAddress = localAddress; this.remoteAddress = remoteAddress; } @@ -108,6 +116,19 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion getExtensions() { + checkNativeSessionInitialized(); + if(this.extensions == null) { + List source = getNativeSession().getNegotiatedExtensions(); + this.extensions = new ArrayList(source.size()); + for(Extension e : source) { + this.extensions.add(new WebSocketExtension.StandardToWebSocketExtensionAdapter(e)); + } + } + return this.extensions; + } + @Override public boolean isOpen() { return ((getNativeSession() != null) && getNativeSession().isOpen()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java index ca8faaa7881..5c01e61cf07 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java @@ -31,6 +31,8 @@ import org.springframework.util.Assert; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.support.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketHttpHeaders; import org.springframework.web.util.UriComponentsBuilder; @@ -44,19 +46,19 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { protected final Log logger = LogFactory.getLog(getClass()); - private static final Set disallowedHeaders = new HashSet(); + private static final Set specialHeaders = new HashSet(); static { - disallowedHeaders.add("cache-control"); - disallowedHeaders.add("cookie"); - disallowedHeaders.add("connection"); - disallowedHeaders.add("host"); - disallowedHeaders.add("sec-websocket-extensions"); - disallowedHeaders.add("sec-websocket-key"); - disallowedHeaders.add("sec-websocket-protocol"); - disallowedHeaders.add("sec-websocket-version"); - disallowedHeaders.add("pragma"); - disallowedHeaders.add("upgrade"); + specialHeaders.add("cache-control"); + specialHeaders.add("cookie"); + specialHeaders.add("connection"); + specialHeaders.add("host"); + specialHeaders.add("sec-websocket-extensions"); + specialHeaders.add("sec-websocket-key"); + specialHeaders.add("sec-websocket-protocol"); + specialHeaders.add("sec-websocket-version"); + specialHeaders.add("pragma"); + specialHeaders.add("upgrade"); } @@ -71,7 +73,7 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { @Override public final ListenableFuture doHandshake(WebSocketHandler webSocketHandler, - HttpHeaders headers, URI uri) { + WebSocketHttpHeaders headers, URI uri) { Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); Assert.notNull(uri, "uri must not be null"); @@ -86,18 +88,19 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { HttpHeaders headersToUse = new HttpHeaders(); if (headers != null) { for (String header : headers.keySet()) { - if (!disallowedHeaders.contains(header.toLowerCase())) { + if (!specialHeaders.contains(header.toLowerCase())) { headersToUse.put(header, headers.get(header)); } } } - List subProtocols = new ArrayList(); - if ((headers != null) && (headers.getSecWebSocketProtocol() != null)) { - subProtocols.addAll(headers.getSecWebSocketProtocol()); - } + List subProtocols = ((headers != null) && (headers.getSecWebSocketProtocol() != null)) ? + headers.getSecWebSocketProtocol() : Collections.emptyList(); + + List extensions = ((headers != null) && (headers.getSecWebSocketExtensions() != null)) ? + headers.getSecWebSocketExtensions() : Collections.emptyList(); - return doHandshakeInternal(webSocketHandler, headersToUse, uri, subProtocols, + return doHandshakeInternal(webSocketHandler, headersToUse, uri, subProtocols, extensions, Collections.emptyMap()); } @@ -109,12 +112,14 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { * headers filtered out, never {@code null} * @param uri the target URI for the handshake, never {@code null} * @param subProtocols requested sub-protocols, or an empty list + * @param extensions requested WebSocket extensions, or an empty list * @param handshakeAttributes attributes to make available via * {@link WebSocketSession#getHandshakeAttributes()}; currently always an empty map. * * @return the established WebSocket session wrapped in a ListenableFuture. */ protected abstract ListenableFuture doHandshakeInternal(WebSocketHandler webSocketHandler, - HttpHeaders headers, URI uri, List subProtocols, Map handshakeAttributes); + HttpHeaders headers, URI uri, List subProtocols, List extensions, + Map handshakeAttributes); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketClient.java index 8711be594d9..ad6dfa6ce65 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketClient.java @@ -22,6 +22,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.support.WebSocketHttpHeaders; /** * Contract for initiating a WebSocket request. As an alternative considering using the @@ -38,6 +39,7 @@ public interface WebSocketClient { ListenableFuture doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVariables); - ListenableFuture doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri); + ListenableFuture doHandshake(WebSocketHandler webSocketHandler, + WebSocketHttpHeaders headers, URI uri); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java index c7d1aded1a4..7a115312e26 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java @@ -25,6 +25,7 @@ import org.springframework.util.concurrent.ListenableFutureCallback; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator; +import org.springframework.web.socket.support.WebSocketHttpHeaders; /** * A WebSocket connection manager that is given a URI, a {@link WebSocketClient}, and a @@ -43,7 +44,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { private WebSocketSession webSocketSession; - private HttpHeaders headers = new HttpHeaders(); + private WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); private final boolean syncClientLifecycle; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java index bcbebb1c007..55421c5866b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java @@ -20,17 +20,14 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.URI; import java.net.UnknownHostException; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.concurrent.Callable; -import javax.websocket.ClientEndpointConfig; +import javax.websocket.*; import javax.websocket.ClientEndpointConfig.Configurator; -import javax.websocket.ContainerProvider; -import javax.websocket.Endpoint; -import javax.websocket.HandshakeResponse; -import javax.websocket.WebSocketContainer; import org.springframework.core.task.AsyncListenableTaskExecutor; import org.springframework.core.task.SimpleAsyncTaskExecutor; @@ -43,6 +40,7 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.StandardWebSocketSession; import org.springframework.web.socket.client.AbstractWebSocketClient; +import org.springframework.web.socket.support.WebSocketExtension; /** * Initiates WebSocket requests to a WebSocket server programatically through the standard @@ -96,7 +94,8 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { @Override protected ListenableFuture doHandshakeInternal(WebSocketHandler webSocketHandler, - HttpHeaders headers, final URI uri, List protocols, Map handshakeAttributes) { + HttpHeaders headers, final URI uri, List protocols, + List extensions, Map handshakeAttributes) { int port = getPort(uri); InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port); @@ -108,6 +107,7 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { final ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create(); configBuidler.configurator(new StandardWebSocketClientConfigurator(headers)); configBuidler.preferredSubprotocols(protocols); + configBuidler.extensions(adaptExtensions(extensions)); final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session); return this.taskExecutor.submitListenable(new Callable() { @@ -119,6 +119,14 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { }); } + private static List adaptExtensions(List extensions) { + List result = new ArrayList(); + for (WebSocketExtension e : extensions) { + result.add(new WebSocketExtension.WebSocketToStandardExtensionAdapter(e)); + } + return result; + } + private InetAddress getLocalHost() { try { return InetAddress.getLocalHost(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java index 5ed8e8914d5..d76ec05d0a7 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java @@ -38,6 +38,7 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession; import org.springframework.web.socket.client.AbstractWebSocketClient; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; @@ -169,10 +170,16 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma @Override public ListenableFuture doHandshakeInternal(WebSocketHandler wsHandler, - HttpHeaders headers, final URI uri, List protocols, Map handshakeAttributes) { + HttpHeaders headers, final URI uri, List protocols, + List extensions, Map handshakeAttributes) { final ClientUpgradeRequest request = new ClientUpgradeRequest(); request.setSubProtocols(protocols); + + for (WebSocketExtension e : extensions) { + request.addExtensions(new WebSocketExtension.WebSocketToJettyExtensionConfigAdapter(e)); + } + for (String header : headers.keySet()) { request.setHeader(header, headers.get(header)); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java index 6c0bfa0e84a..a5a8338b1a5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java @@ -32,13 +32,16 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.support.WebSocketHttpHeaders; /** * A default {@link HandshakeHandler} implementation. Performs initial validation of the * WebSocket handshake request -- possibly rejecting it through the appropriate HTTP * status code -- while also allowing sub-classes to override various parts of the - * negotiation process (e.g. origin validation, sub-protocol negotiation, etc). + * negotiation process (e.g. origin validation, sub-protocol negotiation, + * extensions negotiation, etc). * *

* If the negotiation succeeds, the actual upgrade is delegated to a server-specific @@ -142,8 +145,10 @@ public class DefaultHandshakeHandler implements HandshakeHandler { public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHeaders()); + if (logger.isDebugEnabled()) { - logger.debug("Initiating handshake for " + request.getURI() + ", headers=" + request.getHeaders()); + logger.debug("Initiating handshake for " + request.getURI() + ", headers=" + headers); } try { @@ -153,16 +158,16 @@ public class DefaultHandshakeHandler implements HandshakeHandler { logger.debug("Only HTTP GET is allowed, current method is " + request.getMethod()); return false; } - if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) { + if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { handleInvalidUpgradeHeader(request, response); return false; } - if (!request.getHeaders().getConnection().contains("Upgrade") && - !request.getHeaders().getConnection().contains("upgrade")) { + if (!headers.getConnection().contains("Upgrade") && + !headers.getConnection().contains("upgrade")) { handleInvalidConnectHeader(request, response); return false; } - if (!isWebSocketVersionSupported(request)) { + if (!isWebSocketVersionSupported(headers)) { handleWebSocketVersionNotSupported(request, response); return false; } @@ -170,7 +175,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { response.setStatusCode(HttpStatus.FORBIDDEN); return false; } - String wsKey = request.getHeaders().getSecWebSocketKey(); + String wsKey = headers.getSecWebSocketKey(); if (wsKey == null) { logger.debug("Missing \"Sec-WebSocket-Key\" header"); response.setStatusCode(HttpStatus.BAD_REQUEST); @@ -182,13 +187,17 @@ public class DefaultHandshakeHandler implements HandshakeHandler { "Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex); } - String subProtocol = selectProtocol(request.getHeaders().getSecWebSocketProtocol()); + String subProtocol = selectProtocol(headers.getSecWebSocketProtocol()); + + List requested = headers.getSecWebSocketExtensions(); + List supported = this.requestUpgradeStrategy.getSupportedExtensions(request); + List extensions = filterRequestedExtensions(request, requested, supported); if (logger.isDebugEnabled()) { - logger.debug("Upgrading request, sub-protocol=" + subProtocol); + logger.debug("Upgrading request, sub-protocol=" + subProtocol + ", extensions=" + extensions); } - this.requestUpgradeStrategy.upgrade(request, response, subProtocol, wsHandler, attributes); + this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, wsHandler, attributes); return true; } @@ -205,8 +214,8 @@ public class DefaultHandshakeHandler implements HandshakeHandler { response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes("UTF-8")); } - protected boolean isWebSocketVersionSupported(ServerHttpRequest request) { - String version = request.getHeaders().getSecWebSocketVersion(); + protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders httpHeaders) { + String version = httpHeaders.getSecWebSocketVersion(); String[] supportedVersions = getSupportedVerions(); if (logger.isDebugEnabled()) { logger.debug("Requested version=" + version + ", supported=" + Arrays.toString(supportedVersions)); @@ -229,7 +238,8 @@ public class DefaultHandshakeHandler implements HandshakeHandler { protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) { logger.debug("WebSocket version not supported " + request.getHeaders().get("Sec-WebSocket-Version")); response.setStatusCode(HttpStatus.UPGRADE_REQUIRED); - response.getHeaders().setSecWebSocketVersion(StringUtils.arrayToCommaDelimitedString(getSupportedVerions())); + response.getHeaders().put(WebSocketHttpHeaders.SEC_WEBSOCKET_VERSION, Arrays.asList( + StringUtils.arrayToCommaDelimitedString(getSupportedVerions()))); } protected boolean isValidOrigin(ServerHttpRequest request) { @@ -254,4 +264,27 @@ public class DefaultHandshakeHandler implements HandshakeHandler { return null; } + /** + * Filter the list of requested WebSocket extensions. + *

+ * By default all request extensions are returned. The WebSocket server will further + * compare the requested extensions against the list of supported extensions and + * return only the ones that are both requested and supported. + * + * @param request the current request + * @param requested the list of extensions requested by the client + * @param supported the list of extensions supported by the server + * + * @return the selected extensions or an empty list + */ + protected List filterRequestedExtensions(ServerHttpRequest request, + List requested, List supported) { + + if (requested != null) { + if (logger.isDebugEnabled()) { + logger.debug("Requested extension(s): " + requested + ", supported extension(s): " + supported); + } + } + return requested; + } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java index f2576f6559d..be0ad444e9a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java @@ -45,7 +45,8 @@ public interface HandshakeHandler { * {@link PerConnectionWebSocketHandler} for providing a handler with * per-connection lifecycle. * @param attributes handshake request specific attributes to be set on the WebSocket - * session and thus made available to the {@link WebSocketHandler} + * session via {@link HandshakeInterceptor} and thus made available to the + * {@link WebSocketHandler}; * * @return whether the handshake negotiation was successful or not. In either case the * response status, headers, and body will have been updated to reflect the diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java index b7814264455..f3e950f2212 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java @@ -16,10 +16,12 @@ package org.springframework.web.socket.server; +import java.util.List; import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; /** @@ -36,22 +38,33 @@ public interface RequestUpgradeStrategy { */ String[] getSupportedVersions(); + /** + * @return the WebSocket protocol extensions supported by the underlying WebSocket server. + */ + List getSupportedExtensions(ServerHttpRequest request); + /** * Perform runtime specific steps to complete the upgrade. Invoked after successful * negotiation of the handshake request. * + * * @param request the current request * @param response the current response - * @param acceptedProtocol the accepted sub-protocol, if any + * @param selectedProtocol the selected sub-protocol, if any + * @param selectedExtensions the selected WebSocket protocol extensions * @param wsHandler the handler for WebSocket messages - * @param attributes handshake context attributes + * @param attributes handshake request specific attributes to be set on the WebSocket + * session via {@link org.springframework.web.socket.server.HandshakeInterceptor} + * and thus made available to the + * {@link org.springframework.web.socket.WebSocketHandler}; * * @throws HandshakeFailureException thrown when handshake processing failed to - * complete due to an internal, unrecoverable error, i.e. a server error as - * opposed to a failure to successfully negotiate the requirements of the - * handshake request. + * complete due to an internal, unrecoverable error, i.e. a server error as + * opposed to a failure to successfully negotiate the requirements of the + * handshake request. */ - void upgrade(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, + void upgrade(ServerHttpRequest request, ServerHttpResponse response, + String selectedProtocol, List selectedExtensions, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java index 1d41f06bb14..221ac25634a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java @@ -16,8 +16,7 @@ package org.springframework.web.socket.server.config; -import org.eclipse.jetty.websocket.server.WebSocketHandler; - +import org.springframework.web.socket.WebSocketHandler; /** diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java index 6f0b7838879..392eb2a2dd5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.lang.reflect.Constructor; import java.net.URI; import java.util.Arrays; +import java.util.List; import java.util.Random; import javax.servlet.ServletException; @@ -27,6 +28,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.DeploymentException; import javax.websocket.Endpoint; +import javax.websocket.Extension; import org.glassfish.tyrus.core.ComponentProviderService; import org.glassfish.tyrus.core.EndpointWrapper; @@ -66,6 +68,7 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt private final static Random random = new Random(); + @Override public String[] getSupportedVersions() { return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions()); @@ -73,7 +76,8 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, - String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException { + String selectedProtocol, List selectedExtensions, + Endpoint endpoint) throws HandshakeFailureException { Assert.isTrue(request instanceof ServletServerHttpRequest); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -81,7 +85,9 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt Assert.isTrue(response instanceof ServletServerHttpResponse); HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse(); - WebSocketApplication webSocketApplication = createTyrusEndpoint(servletRequest, endpoint, selectedProtocol); + WebSocketApplication webSocketApplication = createTyrusEndpoint(servletRequest, + endpoint, selectedProtocol, selectedExtensions); + WebSocketEngine webSocketEngine = WebSocketEngine.getEngine(); try { @@ -133,12 +139,13 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt } private WebSocketApplication createTyrusEndpoint(HttpServletRequest request, - Endpoint endpoint, String selectedProtocol) { + Endpoint endpoint, String selectedProtocol, List selectedExtensions) { // shouldn't matter for processing but must be unique String endpointPath = "/" + random.nextLong(); ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint); endpointConfig.setSubprotocols(Arrays.asList(selectedProtocol)); + endpointConfig.setExtensions(selectedExtensions); return createTyrusEndpoint(new EndpointWrapper(endpoint, endpointConfig, ComponentProviderService.create(), null, "/", new ErrorCollector(), endpointConfig.getConfigurator())); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java index 925540c3fcb..2bbcec5f4a8 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java @@ -17,15 +17,22 @@ package org.springframework.web.socket.server.support; import java.net.InetSocketAddress; -import java.util.Map; +import java.util.*; +import javax.servlet.ServletContext; +import javax.servlet.http.HttpServletRequest; import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.WebSocketContainer; +import javax.websocket.server.ServerContainer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.StandardWebSocketSession; @@ -43,9 +50,34 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS protected final Log logger = LogFactory.getLog(getClass()); + private volatile List extensions; + @Override - public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, + public List getSupportedExtensions(ServerHttpRequest request) { + if(this.extensions == null) { + HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + this.extensions = getInstalledExtensions(getContainer(servletRequest)); + } + return this.extensions; + } + + protected ServerContainer getContainer(HttpServletRequest request) { + ServletContext servletContext = request.getServletContext(); + return (ServerContainer) servletContext.getAttribute("javax.websocket.server.ServerContainer"); + } + + protected List getInstalledExtensions(WebSocketContainer container) { + List result = new ArrayList(); + for(Extension e : container.getInstalledExtensions()) { + result.add(new WebSocketExtension.StandardToWebSocketExtensionAdapter(e)); + } + return result; + } + + @Override + public void upgrade(ServerHttpRequest request, ServerHttpResponse response, + String selectedProtocol, List selectedExtensions, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { HttpHeaders headers = request.getHeaders(); @@ -56,10 +88,16 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS StandardWebSocketSession session = new StandardWebSocketSession(headers, attributes, localAddr, remoteAddr); StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, session); - upgradeInternal(request, response, acceptedProtocol, endpoint); + List extensions = new ArrayList(); + for (WebSocketExtension e : selectedExtensions) { + extensions.add(new WebSocketExtension.WebSocketToStandardExtensionAdapter(e)); + } + + upgradeInternal(request, response, selectedProtocol, extensions, endpoint); } protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, - String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException; + String selectedProtocol, List selectedExtensions, + Endpoint endpoint) throws HandshakeFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java index 14f6d1ce6b1..a2dfcdd537b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import javax.servlet.http.HttpServletRequest; @@ -25,15 +27,18 @@ import javax.servlet.http.HttpServletResponse; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeResponse; import org.eclipse.jetty.websocket.api.WebSocketPolicy; +import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; import org.eclipse.jetty.websocket.server.HandshakeRFC6455; import org.eclipse.jetty.websocket.server.WebSocketServerFactory; -import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import org.eclipse.jetty.websocket.servlet.WebSocketCreator; +import org.springframework.core.NamedThreadLocal; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession; @@ -50,10 +55,14 @@ import org.springframework.web.socket.server.RequestUpgradeStrategy; */ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { - private static final String WS_HANDLER_ATTR_NAME = JettyRequestUpgradeStrategy.class.getName() + ".WS_LISTENER"; + private static final ThreadLocal wsContainerHolder = + new NamedThreadLocal("WebSocket Handler Container"); + private WebSocketServerFactory factory; + private volatile List supportedExtensions; + /** * Default constructor that creates {@link WebSocketServerFactory} through its default @@ -74,8 +83,14 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { this.factory.setCreator(new WebSocketCreator() { @Override public Object createWebSocket(UpgradeRequest request, UpgradeResponse response) { - Assert.isInstanceOf(ServletUpgradeRequest.class, request); - return ((ServletUpgradeRequest) request).getServletAttributes().get(WS_HANDLER_ATTR_NAME); + + WebSocketHandlerContainer container = wsContainerHolder.get(); + Assert.state(container != null, "Expected WebSocketHandlerContainer"); + + response.setAcceptedSubProtocol(container.getSelectedProtocol()); + response.setExtensions(container.getExtensionConfigs()); + + return container.getHandler(); } }); try { @@ -92,9 +107,26 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { return new String[] { String.valueOf(HandshakeRFC6455.VERSION) }; } + @Override + public List getSupportedExtensions(ServerHttpRequest request) { + if (this.supportedExtensions == null) { + this.supportedExtensions = getWebSocketExtensions(); + } + return this.supportedExtensions; + } + + private List getWebSocketExtensions() { + List result = new ArrayList(); + for(String name : this.factory.getExtensionFactory().getExtensionNames()) { + result.add(new WebSocketExtension(name)); + } + return result; + } + @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String protocol, WebSocketHandler wsHandler, Map attrs) throws HandshakeFailureException { + String selectedProtocol, List selectedExtensions, + WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { Assert.isInstanceOf(ServletServerHttpRequest.class, request); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -104,11 +136,14 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { Assert.isTrue(this.factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake"); - JettyWebSocketSession session = new JettyWebSocketSession(request.getPrincipal(), attrs); + JettyWebSocketSession session = new JettyWebSocketSession(request.getPrincipal(), attributes); JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(wsHandler, session); + WebSocketHandlerContainer container = + new WebSocketHandlerContainer(handlerAdapter, selectedProtocol, selectedExtensions); + try { - servletRequest.setAttribute(WS_HANDLER_ATTR_NAME, handlerAdapter); + wsContainerHolder.set(container); this.factory.acceptWebSocket(servletRequest, servletResponse); } catch (IOException ex) { @@ -116,8 +151,46 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { "Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex); } finally { - servletRequest.removeAttribute(WS_HANDLER_ATTR_NAME); + wsContainerHolder.remove(); } } + + private static class WebSocketHandlerContainer { + + private final JettyWebSocketHandlerAdapter handler; + + private final String selectedProtocol; + + private final List extensionConfigs; + + private WebSocketHandlerContainer(JettyWebSocketHandlerAdapter handler, String protocol, + List extensions) { + + this.handler = handler; + this.selectedProtocol = protocol; + + if (CollectionUtils.isEmpty(extensions)) { + this.extensionConfigs = null; + } + else { + this.extensionConfigs = new ArrayList(); + for (WebSocketExtension e : extensions) { + this.extensionConfigs.add(new WebSocketExtension.WebSocketToJettyExtensionConfigAdapter(e)); + } + } + } + + private JettyWebSocketHandlerAdapter getHandler() { + return this.handler; + } + + private String getSelectedProtocol() { + return this.selectedProtocol; + } + + private List getExtensionConfigs() { + return this.extensionConfigs; + } + } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java index 4a96dff264d..661425ae3d3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java @@ -19,13 +19,14 @@ package org.springframework.web.socket.server.support; import java.io.IOException; import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; -import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.Endpoint; +import javax.websocket.Extension; import org.apache.tomcat.websocket.server.WsServerContainer; import org.springframework.http.server.ServerHttpRequest; @@ -58,7 +59,8 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, - String acceptedProtocol, Endpoint endpoint) throws HandshakeFailureException { + String selectedProtocol, List selectedExtensions, + Endpoint endpoint) throws HandshakeFailureException { Assert.isTrue(request instanceof ServletServerHttpRequest); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -71,7 +73,8 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg Map pathParams = Collections. emptyMap(); ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(path, endpoint); - endpointConfig.setSubprotocols(Arrays.asList(acceptedProtocol)); + endpointConfig.setSubprotocols(Arrays.asList(selectedProtocol)); + endpointConfig.setExtensions(selectedExtensions); try { getContainer(servletRequest).doUpgrade(servletRequest, servletResponse, endpointConfig, pathParams); @@ -86,10 +89,8 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg } } - public WsServerContainer getContainer(HttpServletRequest servletRequest) { - String attribute = "javax.websocket.server.ServerContainer"; - ServletContext servletContext = servletRequest.getServletContext(); - return (WsServerContainer) servletContext.getAttribute(attribute); + public WsServerContainer getContainer(HttpServletRequest request) { + return (WsServerContainer) super.getContainer(request); } } \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index 5b76febfafd..7373dc05228 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -20,6 +20,8 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -30,6 +32,7 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; @@ -66,6 +69,8 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private String acceptedProtocol; + private List extensions; + public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler, Map handshakeAttributes) { @@ -116,6 +121,11 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.remoteAddress = remoteAddress; } + @Override + public List getExtensions() { + return Collections.emptyList(); + } + /** * Unlike WebSocket where sub-protocol negotiation is part of the * initial handshake, in HTTP transports the same negotiation must diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java index 8474f108c06..b5137f91a47 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import org.springframework.http.HttpHeaders; @@ -27,6 +28,7 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.NativeWebSocketSession; @@ -89,6 +91,12 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession return this.wsSession.getAcceptedProtocol(); } + @Override + public List getExtensions() { + checkDelegateSessionInitialized(); + return this.wsSession.getExtensions(); + } + private void checkDelegateSessionInitialized() { Assert.state(this.wsSession != null, "WebSocketSession not yet initialized"); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketExtension.java b/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketExtension.java new file mode 100644 index 00000000000..6a474313fe4 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketExtension.java @@ -0,0 +1,243 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.support; + +import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.StringUtils; + +import javax.websocket.Extension; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * Represents a WebSocket extension as defined in the RFC 6455. + * WebSocket extensions add protocol features to the WebSocket protocol. The extensions + * used within a session are negotiated during the handshake phase as follows: + *

    + *
  • the client may ask for specific extensions in the HTTP handshake request
  • + *
  • the server responds with the final list of extensions to use in the current session
  • + *
+ * + *

WebSocket Extension HTTP headers may include parameters and follow + * RFC 2616 Section 4.2

+ * + *

Note that the order of extensions in HTTP headers defines their order of execution, + * e.g. extensions "foo, bar" will be executed as "bar(foo(message))".

+ * + * @author Brian Clozel + * @since 4.0 + * @see + * WebSocket Protocol Extensions, RFC 6455 - Section 9 + */ +public class WebSocketExtension { + + private final String name; + + private final Map parameters; + + + /** + * Create a WebSocketExtension with the given name. + * + * @param name the name of the extension + */ + public WebSocketExtension(String name) { + this(name, null); + } + + /** + * Create a WebSocketExtension with the given name and parameters. + * + * @param name the name of the extension + * @param parameters the parameters + */ + public WebSocketExtension(String name, Map parameters) { + Assert.hasLength(name, "extension name must not be empty"); + this.name = name; + if (!CollectionUtils.isEmpty(parameters)) { + Map m = new LinkedCaseInsensitiveMap(parameters.size(), Locale.ENGLISH); + m.putAll(parameters); + this.parameters = Collections.unmodifiableMap(m); + } + else { + this.parameters = Collections.emptyMap(); + } + } + + /** + * @return the name of the extension + */ + public String getName() { + return this.name; + } + + /** + * @return the parameters of the extension, never {@code null} + */ + public Map getParameters() { + return this.parameters; + } + + /** + * Parse the given, comma-separated string into a list of {@code WebSocketExtension} objects. + *

This method can be used to parse a "Sec-WebSocket-Extension" extensions. + * @param extensions the string to parse + * @return the list of extensions + * @throws IllegalArgumentException if the string cannot be parsed + */ + public static List parseExtensions(String extensions) { + if (extensions == null || !StringUtils.hasText(extensions)) { + return Collections.emptyList(); + } + else { + List result = new ArrayList(); + for(String token : extensions.split(",")) { + result.add(parseExtension(token)); + } + return result; + } + } + + private static WebSocketExtension parseExtension(String extension) { + Assert.doesNotContain(extension, ",", "Expected a single extension value: " + extension); + String[] parts = StringUtils.tokenizeToStringArray(extension, ";"); + String name = parts[0].trim(); + + Map parameters = null; + if (parts.length > 1) { + parameters = new LinkedHashMap(parts.length - 1); + for (int i = 1; i < parts.length; i++) { + String parameter = parts[i]; + int eqIndex = parameter.indexOf('='); + if (eqIndex != -1) { + String attribute = parameter.substring(0, eqIndex); + String value = parameter.substring(eqIndex + 1, parameter.length()); + parameters.put(attribute, value); + } + } + } + + return new WebSocketExtension(name, parameters); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if ((o == null) || (getClass() != o.getClass())) { + return false; + } + WebSocketExtension that = (WebSocketExtension) o; + if (!name.equals(that.name)) { + return false; + } + if (!parameters.equals(that.parameters)) { + return false; + } + return true; + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + parameters.hashCode(); + return result; + } + + @Override + public String toString() { + StringBuilder str = new StringBuilder(); + str.append(this.name); + for (String param : parameters.keySet()) { + str.append(';'); + str.append(param); + str.append('='); + str.append(this.parameters.get(param)); + } + return str.toString(); + } + + + // Standard WebSocketExtension adapters + + public static class StandardToWebSocketExtensionAdapter extends WebSocketExtension { + + public StandardToWebSocketExtensionAdapter(Extension ext) { + super(ext.getName()); + for (Extension.Parameter p : ext.getParameters()) { + super.getParameters().put(p.getName(), p.getValue()); + } + } + } + + public static class WebSocketToStandardExtensionAdapter implements Extension { + + private final String name; + + private final List parameters = new ArrayList(); + + public WebSocketToStandardExtensionAdapter(final WebSocketExtension ext) { + this.name = ext.getName(); + List params = new ArrayList(); + for (final String paramName : ext.getParameters().keySet()) { + this.parameters.add(new Parameter() { + @Override + public String getName() { + return paramName; + } + @Override + public String getValue() { + return ext.getParameters().get(paramName); + } + }); + } + } + + @Override + public String getName() { + return name; + } + + @Override + public List getParameters() { + return this.parameters; + } + } + + // Jetty WebSocketExtension adapters + + public static class WebSocketToJettyExtensionConfigAdapter extends ExtensionConfig { + + public WebSocketToJettyExtensionConfigAdapter(WebSocketExtension extension) { + super(extension.getName()); + for (Map.Entry p : extension.getParameters().entrySet()) { + super.setParameter(p.getKey(), p.getValue()); + } + } + } + + + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketHttpHeaders.java b/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketHttpHeaders.java new file mode 100644 index 00000000000..c6a8e5a11e1 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketHttpHeaders.java @@ -0,0 +1,328 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.support; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.CollectionUtils; + +import java.io.Serializable; +import java.util.*; + +/** + * An {@link org.springframework.http.HttpHeaders} variant that adds support for + * the HTTP headers defined by the WebSocket specification RFC 6455. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class WebSocketHttpHeaders extends HttpHeaders { + + public static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; + + public static final String SEC_WEBSOCKET_EXTENSIONS = "Sec-WebSocket-Extensions"; + + public static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; + + public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; + + public static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version"; + + private final HttpHeaders headers; + + + + + /** + * Create a new instance. + */ + public WebSocketHttpHeaders() { + this(new HttpHeaders(), false); + } + + /** + * Create an instance that wraps the given pre-existing HttpHeaders and also + * propagate all changes to it. + * + * @param headers the HTTP headers to wrap + */ + public WebSocketHttpHeaders(HttpHeaders headers) { + this(headers, false); + } + + /** + * Private constructor that can create read-only {@code WebSocketHttpHeader} instances. + */ + private WebSocketHttpHeaders(HttpHeaders headers, boolean readOnly) { + this.headers = readOnly ? HttpHeaders.readOnlyHttpHeaders(headers) : headers; + } + + /** + * Returns {@code WebSocketHttpHeaders} object that can only be read, not written to. + */ + public static WebSocketHttpHeaders readOnlyWebSocketHttpHeaders(WebSocketHttpHeaders headers) { + return new WebSocketHttpHeaders(headers, true); + } + + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Accept} header. + * @param secWebSocketAccept the value of the header + */ + public void setSecWebSocketAccept(String secWebSocketAccept) { + set(SEC_WEBSOCKET_ACCEPT, secWebSocketAccept); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Accept} header. + * @return the value of the header + */ + public String getSecWebSocketAccept() { + return getFirst(SEC_WEBSOCKET_ACCEPT); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Extensions} header. + * @return the value of the header + */ + public List getSecWebSocketExtensions() { + List values = get(SEC_WEBSOCKET_EXTENSIONS); + if (CollectionUtils.isEmpty(values)) { + return Collections.emptyList(); + } + else { + List result = new ArrayList(values.size()); + for (String value : values) { + result.addAll(WebSocketExtension.parseExtensions(value)); + } + return result; + } + } + + /** + * Sets the (new) value(s) of the {@code Sec-WebSocket-Extensions} header. + * @param extensions the values for the header + */ + public void setSecWebSocketExtensions(List extensions) { + List result = new ArrayList(extensions.size()); + for(WebSocketExtension extension : extensions) { + result.add(extension.toString()); + } + set(SEC_WEBSOCKET_EXTENSIONS, toCommaDelimitedString(result)); + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Key} header. + * @param secWebSocketKey the value of the header + */ + public void setSecWebSocketKey(String secWebSocketKey) { + set(SEC_WEBSOCKET_KEY, secWebSocketKey); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Key} header. + * @return the value of the header + */ + public String getSecWebSocketKey() { + return getFirst(SEC_WEBSOCKET_KEY); + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Protocol} header. + * @param secWebSocketProtocol the value of the header + */ + public void setSecWebSocketProtocol(String secWebSocketProtocol) { + if (secWebSocketProtocol != null) { + set(SEC_WEBSOCKET_PROTOCOL, secWebSocketProtocol); + } + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Protocol} header. + * @param secWebSocketProtocols the value of the header + */ + public void setSecWebSocketProtocol(List secWebSocketProtocols) { + set(SEC_WEBSOCKET_PROTOCOL, toCommaDelimitedString(secWebSocketProtocols)); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Key} header. + * @return the value of the header + */ + public List getSecWebSocketProtocol() { + List values = get(SEC_WEBSOCKET_PROTOCOL); + if (CollectionUtils.isEmpty(values)) { + return Collections.emptyList(); + } + else if (values.size() == 1) { + return getFirstValueAsList(SEC_WEBSOCKET_PROTOCOL); + } + else { + return values; + } + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Version} header. + * @param secWebSocketVersion the value of the header + */ + public void setSecWebSocketVersion(String secWebSocketVersion) { + set(SEC_WEBSOCKET_VERSION, secWebSocketVersion); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Version} header. + * @return the value of the header + */ + public String getSecWebSocketVersion() { + return getFirst(SEC_WEBSOCKET_VERSION); + } + + + // Single string methods + + /** + * Return the first header value for the given header name, if any. + * @param headerName the header name + * @return the first header value; or {@code null} + */ + @Override + public String getFirst(String headerName) { + return this.headers.getFirst(headerName); + } + + /** + * Add the given, single header value under the given name. + * @param headerName the header name + * @param headerValue the header value + * @throws UnsupportedOperationException if adding headers is not supported + * @see #put(String, List) + * @see #set(String, String) + */ + @Override + public void add(String headerName, String headerValue) { + this.headers.add(headerName, headerValue); + } + + /** + * Set the given, single header value under the given name. + * @param headerName the header name + * @param headerValue the header value + * @throws UnsupportedOperationException if adding headers is not supported + * @see #put(String, List) + * @see #add(String, String) + */ + @Override + public void set(String headerName, String headerValue) { + this.headers.set(headerName, headerValue); + } + + @Override + public void setAll(Map values) { + this.headers.setAll(values); + } + + @Override + public Map toSingleValueMap() { + return this.headers.toSingleValueMap(); + } + + // Map implementation + + @Override + public int size() { + return this.headers.size(); + } + + @Override + public boolean isEmpty() { + return this.headers.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return this.headers.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + return this.headers.containsValue(value); + } + + @Override + public List get(Object key) { + return this.headers.get(key); + } + + @Override + public List put(String key, List value) { + return this.headers.put(key, value); + } + + @Override + public List remove(Object key) { + return this.headers.remove(key); + } + + @Override + public void putAll(Map> m) { + this.headers.putAll(m); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.keySet(); + } + + @Override + public Collection> values() { + return this.headers.values(); + } + + @Override + public Set>> entrySet() { + return this.headers.entrySet(); + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof WebSocketHttpHeaders)) { + return false; + } + WebSocketHttpHeaders otherHeaders = (WebSocketHttpHeaders) other; + return this.headers.equals(otherHeaders.headers); + } + + @Override + public int hashCode() { + return this.headers.hashCode(); + } + + @Override + public String toString() { + return this.headers.toString(); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java index d0dd2f95014..0845e83018d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java @@ -117,7 +117,7 @@ public abstract class AbstractWebSocketIntegrationTests { static abstract class AbstractRequestUpgradeStrategyConfig { @Bean - public HandshakeHandler handshakeHandler() { + public DefaultHandshakeHandler handshakeHandler() { return new DefaultHandshakeHandler(requestUpgradeStrategy()); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java new file mode 100644 index 00000000000..f760ee110fb --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import org.hamcrest.Matchers; +import org.junit.Test; +import org.springframework.web.socket.support.WebSocketExtension; + +import java.util.ArrayList; +import java.util.List; + +/** + * Test fixture for {@link org.springframework.web.socket.support.WebSocketExtension} + * @author Brian Clozel + */ +public class WebSocketExtensionTests { + + @Test + public void parseHeaderSingle() { + List extensions = WebSocketExtension.parseExtensions("x-test-extension ; foo=bar ; bar=baz"); + assertThat(extensions, Matchers.hasSize(1)); + WebSocketExtension extension = extensions.get(0); + + assertEquals("x-test-extension", extension.getName()); + assertEquals(2, extension.getParameters().size()); + assertEquals("bar", extension.getParameters().get("foo")); + assertEquals("baz", extension.getParameters().get("bar")); + } + + @Test + public void parseHeaderMultiple() { + List extensions = WebSocketExtension.parseExtensions("x-foo-extension, x-bar-extension"); + assertThat(extensions, Matchers.hasSize(2)); + assertEquals("x-foo-extension", extensions.get(0).getName()); + assertEquals("x-bar-extension", extensions.get(1).getName()); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java new file mode 100644 index 00000000000..3a2c5057d65 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java @@ -0,0 +1,173 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; +import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; +import org.springframework.web.socket.client.endpoint.StandardWebSocketClient; +import org.springframework.web.socket.client.jetty.JettyWebSocketClient; +import org.springframework.web.socket.server.DefaultHandshakeHandler; +import org.springframework.web.socket.server.HandshakeFailureException; +import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.config.EnableWebSocket; +import org.springframework.web.socket.server.config.WebSocketConfigurer; +import org.springframework.web.socket.server.config.WebSocketHandlerRegistry; +import org.springframework.web.socket.support.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketHttpHeaders; + +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; + +import static org.junit.Assert.assertEquals; + +/** + * Client and server-side WebSocket integration tests. + * + * @author Rossen Stoyanchev + */ +@RunWith(Parameterized.class) +public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { + + @Parameterized.Parameters + public static Iterable arguments() { + return Arrays.asList(new Object[][]{ + {new JettyWebSocketTestServer(), new JettyWebSocketClient()}, + {new TomcatWebSocketTestServer(), new StandardWebSocketClient()} + }); + }; + + + @Override + protected Class[] getAnnotatedConfigClasses() { + return new Class[] { TestWebSocketConfigurer.class }; + } + + @Test + public void subProtocolNegotiation() throws Exception { + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); + headers.setSecWebSocketProtocol("foo"); + + WebSocketSession session = this.webSocketClient.doHandshake( + new WebSocketHandlerAdapter(), headers, new URI(getWsBaseUrl() + "/ws")).get(); + + assertEquals("foo", session.getAcceptedProtocol()); + } + + + @Configuration + @EnableWebSocket + static class TestWebSocketConfigurer implements WebSocketConfigurer { + + @Autowired + private DefaultHandshakeHandler handshakeHandler; // can't rely on classpath for server detection + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + this.handshakeHandler.setSupportedProtocols("foo", "bar", "baz"); + registry.addHandler(serverHandler(), "/ws").setHandshakeHandler(this.handshakeHandler); + } + + @Bean + public TestServerWebSocketHandler serverHandler() { + return new TestServerWebSocketHandler(); + } + } + + private static class TestClientWebSocketHandler extends TextWebSocketHandlerAdapter { + + private final TextMessage[] messagesToSend; + + private final int expected; + + private final List actual = new CopyOnWriteArrayList(); + + private final CountDownLatch latch; + + + public TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) { + this.messagesToSend = messagesToSend; + this.expected = expectedNumberOfMessages; + this.latch = new CountDownLatch(this.expected); + } + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + for (TextMessage message : this.messagesToSend) { + session.sendMessage(message); + } + } + + @Override + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + this.actual.add(message); + this.latch.countDown(); + } + } + + private static class TestServerWebSocketHandler extends TextWebSocketHandlerAdapter { + + } + + private static class TestRequestUpgradeStrategy implements RequestUpgradeStrategy { + + private final RequestUpgradeStrategy delegate; + + private List extensions= new ArrayList(); + + + private TestRequestUpgradeStrategy(RequestUpgradeStrategy delegate, String... supportedExtensions) { + this.delegate = delegate; + for (String name : supportedExtensions) { + this.extensions.add(new WebSocketExtension(name)); + } + } + + @Override + public String[] getSupportedVersions() { + return this.delegate.getSupportedVersions(); + } + + @Override + public List getSupportedExtensions(ServerHttpRequest request) { + return this.extensions; + } + + @Override + public void upgrade(ServerHttpRequest request, ServerHttpResponse response, + String selectedProtocol, List selectedExtensions, + WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { + + this.delegate.upgrade(request, response, selectedProtocol, selectedExtensions, wsHandler, attributes); + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java index 4bf61438eda..0a4c4ec7c4f 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java @@ -31,6 +31,7 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator; import org.springframework.web.socket.support.WebSocketHandlerDecorator; +import org.springframework.web.socket.support.WebSocketHttpHeaders; import org.springframework.web.util.UriComponentsBuilder; import static org.junit.Assert.*; @@ -54,7 +55,7 @@ public class WebSocketConnectionManagerTests { manager.setSubProtocols(subprotocols); manager.openConnection(); - HttpHeaders expectedHeaders = new HttpHeaders(); + WebSocketHttpHeaders expectedHeaders = new WebSocketHttpHeaders(); expectedHeaders.setSecWebSocketProtocol(subprotocols); assertEquals(expectedHeaders, client.headers); @@ -150,7 +151,7 @@ public class WebSocketConnectionManagerTests { @Override public ListenableFuture doHandshake(WebSocketHandler webSocketHandler, - HttpHeaders headers, URI uri) { + WebSocketHttpHeaders headers, URI uri) { this.webSocketHandler = webSocketHandler; this.headers = headers; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java index 8fe32355fce..5b2919ed568 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java @@ -34,6 +34,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; +import org.springframework.web.socket.support.WebSocketHttpHeaders; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -51,12 +52,12 @@ public class StandardWebSocketClientTests { private WebSocketHandler wsHandler; - private HttpHeaders headers; + private WebSocketHttpHeaders headers; @Before public void setup() { - this.headers = new HttpHeaders(); + this.headers = new WebSocketHttpHeaders(); this.wsHandler = new WebSocketHandlerAdapter(); this.wsContainer = mock(WebSocketContainer.class); this.wsClient = new StandardWebSocketClient(this.wsContainer); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java index c1960a52b3e..14d0b770eae 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java @@ -36,6 +36,7 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession; import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; +import org.springframework.web.socket.support.WebSocketHttpHeaders; import static org.junit.Assert.*; @@ -80,7 +81,7 @@ public class JettyWebSocketClientTests { @Test public void doHandshake() throws Exception { - HttpHeaders headers = new HttpHeaders(); + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); headers.setSecWebSocketProtocol(Arrays.asList("echo")); this.wsSession = this.client.doHandshake(new TextWebSocketHandlerAdapter(), headers, new URI(this.wsUrl)).get(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java index 537b8a524df..5a96e5ec393 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java @@ -24,8 +24,10 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.web.socket.AbstractHttpRequestTests; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; +import org.springframework.web.socket.support.WebSocketHttpHeaders; import static org.mockito.Mockito.*; @@ -57,19 +59,22 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] { "13" }); + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); + this.servletRequest.setMethod("GET"); - this.request.getHeaders().setUpgrade("WebSocket"); - this.request.getHeaders().setConnection("Upgrade"); - this.request.getHeaders().setSecWebSocketVersion("13"); - this.request.getHeaders().setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); - this.request.getHeaders().setSecWebSocketProtocol("STOMP"); + headers.setUpgrade("WebSocket"); + headers.setConnection("Upgrade"); + headers.setSecWebSocketVersion("13"); + headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); + headers.setSecWebSocketProtocol("STOMP"); WebSocketHandler handler = new TextWebSocketHandlerAdapter(); Map attributes = Collections.emptyMap(); this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); - verify(this.upgradeStrategy).upgrade(this.request, this.response, "STOMP", handler, attributes); + verify(this.upgradeStrategy).upgrade(this.request, this.response, + "STOMP", Collections.emptyList(), handler, attributes); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java index ab36aa549ec..e769d5b6520 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java @@ -41,7 +41,7 @@ import static org.junit.Assert.*; /** - * Test fixture for WebSocket Java config support. + * Integration tests for WebSocket Java server-side configuration. * * @author Rossen Stoyanchev */ diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java index 1cee547aab2..730c12cdcdc 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java @@ -26,6 +26,7 @@ import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; @@ -58,6 +59,8 @@ public class TestSockJsSession extends AbstractSockJsSession { private String subProtocol; + private List extensions = new ArrayList<>(); + public TestSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler wsHandler, Map attributes) { @@ -80,61 +83,37 @@ public class TestSockJsSession extends AbstractSockJsSession { return this.headers; } - /** - * @return the headers - */ public HttpHeaders getHeaders() { return this.headers; } - /** - * @param headers the headers to set - */ public void setHeaders(HttpHeaders headers) { this.headers = headers; } - /** - * @return the principal - */ @Override public Principal getPrincipal() { return this.principal; } - /** - * @param principal the principal to set - */ public void setPrincipal(Principal principal) { this.principal = principal; } - /** - * @return the localAddress - */ @Override public InetSocketAddress getLocalAddress() { return this.localAddress; } - /** - * @param remoteAddress the remoteAddress to set - */ public void setLocalAddress(InetSocketAddress localAddress) { this.localAddress = localAddress; } - /** - * @return the remoteAddress - */ @Override public InetSocketAddress getRemoteAddress() { return this.remoteAddress; } - /** - * @param remoteAddress the remoteAddress to set - */ public void setRemoteAddress(InetSocketAddress remoteAddress) { this.remoteAddress = remoteAddress; } @@ -148,6 +127,15 @@ public class TestSockJsSession extends AbstractSockJsSession { this.subProtocol = protocol; } + @Override + public List getExtensions() { + return this.extensions; + } + + public void setExtensions(List extensions) { + this.extensions = extensions; + } + public CloseStatus getCloseStatus() { return this.closeStatus; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java index 6f16888e688..7c8775cfc73 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java @@ -51,6 +51,8 @@ public class TestWebSocketSession implements WebSocketSession { private String protocol; + private List extensions = new ArrayList(); + private boolean open; private final List> messages = new ArrayList<>(); @@ -60,32 +62,20 @@ public class TestWebSocketSession implements WebSocketSession { private HttpHeaders headers; - /** - * @return the id - */ @Override public String getId() { return this.id; } - /** - * @param id the id to set - */ public void setId(String id) { this.id = id; } - /** - * @return the uri - */ @Override public URI getUri() { return this.uri; } - /** - * @param uri the uri to set - */ public void setUri(URI uri) { this.uri = uri; } @@ -96,105 +86,72 @@ public class TestWebSocketSession implements WebSocketSession { return this.headers; } - /** - * @return the headers - */ public HttpHeaders getHeaders() { return this.headers; } - /** - * @param headers the headers to set - */ public void setHeaders(HttpHeaders headers) { this.headers = headers; } - /** - * @param attributes the attributes to set - */ public void setHandshakeAttributes(Map attributes) { this.attributes = attributes; } - /** - * @return the attributes - */ @Override public Map getHandshakeAttributes() { return this.attributes; } - /** - * @return the principal - */ @Override public Principal getPrincipal() { return this.principal; } - /** - * @param principal the principal to set - */ public void setPrincipal(Principal principal) { this.principal = principal; } - /** - * @return the localAddress - */ @Override public InetSocketAddress getLocalAddress() { return this.localAddress; } - /** - * @param remoteAddress the remoteAddress to set - */ public void setLocalAddress(InetSocketAddress localAddress) { this.localAddress = localAddress; } - /** - * @return the remoteAddress - */ @Override public InetSocketAddress getRemoteAddress() { return this.remoteAddress; } - /** - * @param remoteAddress the remoteAddress to set - */ public void setRemoteAddress(InetSocketAddress remoteAddress) { this.remoteAddress = remoteAddress; } - /** - * @return the subProtocol - */ public String getAcceptedProtocol() { return this.protocol; } - /** - * @param protocol the subProtocol to set - */ public void setAcceptedProtocol(String protocol) { this.protocol = protocol; } - /** - * @return the open - */ + @Override + public List getExtensions() { + return this.extensions; + } + + public void setExtensions(List extensions) { + this.extensions = extensions; + } + @Override public boolean isOpen() { return this.open; } - /** - * @param open the open to set - */ public void setOpen(boolean open) { this.open = open; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/WebSocketHttpHeadersTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/WebSocketHttpHeadersTests.java new file mode 100644 index 00000000000..7cca7129775 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/WebSocketHttpHeadersTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.support; + +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertThat; + +/** + * Unit tests for WebSocketHttpHeaders. + * + * @author Rossen Stoyanchev + */ +public class WebSocketHttpHeadersTests { + + private WebSocketHttpHeaders headers; + + @Before + public void setUp() { + headers = new WebSocketHttpHeaders(); + } + + @Test + public void parseWebSocketExtensions() { + List extensions = new ArrayList(); + extensions.add("x-foo-extension, x-bar-extension"); + extensions.add("x-test-extension"); + this.headers.put(WebSocketHttpHeaders.SEC_WEBSOCKET_EXTENSIONS, extensions); + + List parsedExtensions = this.headers.getSecWebSocketExtensions(); + assertThat(parsedExtensions, Matchers.hasSize(3)); + } + +}