From 6d00a3f0ee795e1a5bd9ffe02c584623828dbc37 Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Thu, 24 Oct 2013 11:28:21 +0200 Subject: [PATCH 1/2] Add support for WebSocket Protocol Extensions This commits adds simple, overridable WebSocket Extension filtering during the handshake phase and adds that information in the WebSocket session. The actual WebSocket Extension negotiation happens within the server implementation (Glassfish, Jetty, Tomcat...), so one can only remove requested extensions from the list provided by the WebSocket client. See RFC6455 Section 9. Issue: SPR-10843 --- .../handler/websocket/SubProtocolHandler.java | 2 +- .../web/socket/WebSocketExtension.java | 167 ++++++++++++++++++ .../web/socket/WebSocketSession.java | 7 + .../socket/adapter/JettyWebSocketSession.java | 18 ++ .../adapter/StandardWebSocketSession.java | 24 +++ .../server/DefaultHandshakeHandler.java | 34 +++- .../socket/server/RequestUpgradeStrategy.java | 8 + .../server/config/WebSocketConfigurer.java | 3 +- ...stractGlassFishRequestUpgradeStrategy.java | 29 +++ .../AbstractStandardUpgradeStrategy.java | 11 ++ .../support/JettyRequestUpgradeStrategy.java | 16 ++ .../support/TomcatRequestUpgradeStrategy.java | 18 ++ .../session/AbstractHttpSockJsSession.java | 7 + .../session/WebSocketServerSockJsSession.java | 8 + .../web/socket/WebSocketExtensionTests.java | 61 +++++++ .../transport/session/TestSockJsSession.java | 17 +- .../socket/support/TestWebSocketSession.java | 17 +- 17 files changed, 441 insertions(+), 6 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java index 1d3dcb30a04..fa780ed94df 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java @@ -69,7 +69,7 @@ public interface SubProtocolHandler { /** * Resolve the session id from the given message or return {@code null}. * - * @param the message to resolve the session id from + * @param message the message to resolve the session id from */ String resolveSessionId(Message message); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java new file mode 100644 index 00000000000..399a3eca810 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java @@ -0,0 +1,167 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * WebSocket Protocol extension. + * Adds new protocol features to the WebSocket protocol; the extensions used + * within a session are negotiated during the handshake phase: + * + * + *

WebSocket Extension HTTP headers may include parameters and follow + * RFC 2616 Section 4.2 + * specifications.

+ * + *

Note that the order of extensions in HTTP headers defines their order of execution, + * e.g. extensions "foo, bar" will be executed as "bar(foo(message))".

+ * + * @author Brian Clozel + * @since 4.0 + * @see + * WebSocket Protocol Extensions, RFC 6455 - Section 9 + */ +public class WebSocketExtension { + + private final String name; + + private final Map parameters; + + public WebSocketExtension(String name) { + this(name,null); + } + + public WebSocketExtension(String name, Map parameters) { + Assert.hasLength(name, "extension name must not be empty"); + this.name = name; + if (!CollectionUtils.isEmpty(parameters)) { + Map m = new LinkedCaseInsensitiveMap(parameters.size(), Locale.ENGLISH); + m.putAll(parameters); + this.parameters = Collections.unmodifiableMap(m); + } + else { + this.parameters = Collections.emptyMap(); + } + } + + /** + * @return the name of the extension + */ + public String getName() { + return this.name; + } + + /** + * @return the parameters of the extension + */ + public Map getParameters() { + return this.parameters; + } + + /** + * Parse a list of raw WebSocket extension headers + */ + public static List parseHeaders(List headers) { + if (headers == null || headers.isEmpty()) { + return Collections.emptyList(); + } + else { + List result = new ArrayList(headers.size()); + for (String header : headers) { + result.addAll(parseHeader(header)); + } + return result; + } + } + + /** + * Parse a raw WebSocket extension header + */ + public static List parseHeader(String header) { + if (header == null || !StringUtils.hasText(header)) { + return Collections.emptyList(); + } + else { + List result = new ArrayList(); + for(String token : header.split(",")) { + result.add(parse(token)); + } + return result; + } + } + + private static WebSocketExtension parse(String extension) { + Assert.doesNotContain(extension,",","this string contains multiple extension declarations"); + String[] parts = StringUtils.tokenizeToStringArray(extension, ";"); + String name = parts[0].trim(); + + Map parameters = null; + if (parts.length > 1) { + parameters = new LinkedHashMap(parts.length - 1); + for (int i = 1; i < parts.length; i++) { + String parameter = parts[i]; + int eqIndex = parameter.indexOf('='); + if (eqIndex != -1) { + String attribute = parameter.substring(0, eqIndex); + String value = parameter.substring(eqIndex + 1, parameter.length()); + parameters.put(attribute, value); + } + } + } + + return new WebSocketExtension(name,parameters); + } + + /** + * Convert a list of WebSocketExtensions to a list of String, + * which is convenient for native HTTP headers. + */ + public static List toStringList(List extensions) { + List result = new ArrayList(extensions.size()); + for(WebSocketExtension extension : extensions) { + result.add(extension.toString()); + } + return result; + } + + @Override + public String toString() { + StringBuilder str = new StringBuilder(); + str.append(this.name); + for (String param : parameters.keySet()) { + str.append(';'); + str.append(param); + str.append('='); + str.append(this.parameters.get(param)); + } + return str.toString(); + } +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java index b935e21889d..a4634ddc398 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import org.springframework.http.HttpHeaders; @@ -79,6 +80,12 @@ public interface WebSocketSession { */ String getAcceptedProtocol(); + /** + * Return the negotiated extensions or {@code null} if none was specified or + * negotiated successfully. + */ + List getExtensions(); + /** * Return whether the connection is still open. */ diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java index b7bc0a8130a..62222e394ee 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java @@ -20,8 +20,11 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; import org.springframework.http.HttpHeaders; import org.springframework.util.ObjectUtils; import org.springframework.web.socket.BinaryMessage; @@ -29,6 +32,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketSession; /** @@ -42,6 +46,8 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion extensions; + private final Principal principal; @@ -104,6 +110,18 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion getExtensions() { + checkNativeSessionInitialized(); + if(this.extensions == null) { + this.extensions = new ArrayList(); + for(ExtensionConfig ext : getNativeSession().getUpgradeResponse().getExtensions()) { + this.extensions.add(new WebSocketExtension(ext.getName(),ext.getParameters())); + } + } + return this.extensions; + } + @Override public boolean isOpen() { return ((getNativeSession() != null) && getNativeSession().isOpen()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java index 7ddd3a87187..6a9fa9be8c5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java @@ -20,10 +20,14 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Extension; import org.springframework.http.HttpHeaders; import org.springframework.util.StringUtils; @@ -32,6 +36,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketSession; /** @@ -48,6 +53,8 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion extensions; + /** * Class constructor. @@ -108,6 +115,23 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion getExtensions() { + checkNativeSessionInitialized(); + if(this.extensions == null) { + List nativeExtensions = getNativeSession().getNegotiatedExtensions(); + this.extensions = new ArrayList(nativeExtensions.size()); + for(Extension nativeExtension : nativeExtensions) { + Map parameters = new HashMap(); + for (Extension.Parameter param : nativeExtension.getParameters()) { + parameters.put(param.getName(),param.getValue()); + } + this.extensions.add(new WebSocketExtension(nativeExtension.getName(),parameters)); + } + } + return this.extensions; + } + @Override public boolean isOpen() { return ((getNativeSession() != null) && getNativeSession().isOpen()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java index 6c0bfa0e84a..f4b02b3a01d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java @@ -32,13 +32,15 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; /** * A default {@link HandshakeHandler} implementation. Performs initial validation of the * WebSocket handshake request -- possibly rejecting it through the appropriate HTTP * status code -- while also allowing sub-classes to override various parts of the - * negotiation process (e.g. origin validation, sub-protocol negotiation, etc). + * negotiation process (e.g. origin validation, sub-protocol negotiation, + * extensions negotiation, etc). * *

* If the negotiation succeeds, the actual upgrade is delegated to a server-specific @@ -188,6 +190,13 @@ public class DefaultHandshakeHandler implements HandshakeHandler { logger.debug("Upgrading request, sub-protocol=" + subProtocol); } + List requestedExtensions = WebSocketExtension + .parseHeaders(request.getHeaders().getSecWebSocketExtensions()); + + List filteredExtensions = filterRequestedExtensions(requestedExtensions, + this.requestUpgradeStrategy.getAvailableExtensions(request)); + request.getHeaders().setSecWebSocketExtensions(WebSocketExtension.toStringList(filteredExtensions)); + this.requestUpgradeStrategy.upgrade(request, response, subProtocol, wsHandler, attributes); return true; @@ -254,4 +263,27 @@ public class DefaultHandshakeHandler implements HandshakeHandler { return null; } + /** + * Filter the list of WebSocket Extensions requested by the client. + * Since the negotiation process happens during the upgrade phase within the server + * implementation, one can customize the applied extensions only by filtering the + * requested extensions by the client. + * + *

The default implementation of this method doesn't filter any of the extensions + * requested by the client. + * @param requestedExtensions the list of extensions requested by the client + * @param supportedExtensions the list of extensions supported by the server + * @return the filtered list of requested extensions + */ + protected List filterRequestedExtensions(List requestedExtensions, + List supportedExtensions) { + + if (requestedExtensions != null) { + if (logger.isDebugEnabled()) { + logger.debug("Requested extension(s): " + requestedExtensions + + ", supported extension(s): " + supportedExtensions); + } + } + return requestedExtensions; + } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java index b7814264455..89b197fe38c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java @@ -16,10 +16,12 @@ package org.springframework.web.socket.server; +import java.util.List; import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; /** @@ -36,6 +38,12 @@ public interface RequestUpgradeStrategy { */ String[] getSupportedVersions(); + /** + * @return the list of available WebSocket protocol extensions, + * implemented by the underlying WebSocket server. + */ + List getAvailableExtensions(ServerHttpRequest request); + /** * Perform runtime specific steps to complete the upgrade. Invoked after successful * negotiation of the handshake request. diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java index 1d41f06bb14..221ac25634a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java @@ -16,8 +16,7 @@ package org.springframework.web.socket.server.config; -import org.eclipse.jetty.websocket.server.WebSocketHandler; - +import org.springframework.web.socket.WebSocketHandler; /** diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java index 6f0b7838879..7349615f500 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java @@ -19,15 +19,21 @@ package org.springframework.web.socket.server.support; import java.io.IOException; import java.lang.reflect.Constructor; import java.net.URI; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Random; +import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.DeploymentException; import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.server.ServerContainer; +import org.apache.tomcat.websocket.server.WsServerContainer; import org.glassfish.tyrus.core.ComponentProviderService; import org.glassfish.tyrus.core.EndpointWrapper; import org.glassfish.tyrus.core.ErrorCollector; @@ -47,6 +53,7 @@ import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration; import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean; @@ -66,11 +73,26 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt private final static Random random = new Random(); + private List availableExtensions; + @Override public String[] getSupportedVersions() { return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions()); } + @Override + public List getAvailableExtensions(ServerHttpRequest request) { + + if(this.availableExtensions == null) { + this.availableExtensions = new ArrayList(); + HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + for(Extension extension : getContainer(servletRequest).getInstalledExtensions()) { + this.availableExtensions.add(parseStandardExtension(extension)); + } + } + return this.availableExtensions; + } + @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException { @@ -103,6 +125,13 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt } } + public ServerContainer getContainer(HttpServletRequest servletRequest) { + + String attributeName = "javax.websocket.server.ServerContainer"; + ServletContext servletContext = servletRequest.getServletContext(); + return (ServerContainer)servletContext.getAttribute(attributeName); + } + private boolean performUpgrade(HttpServletRequest request, HttpServletResponse response, HttpHeaders headers, WebSocketApplication wsApp) throws IOException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java index 925540c3fcb..0f62add3587 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java @@ -17,15 +17,18 @@ package org.springframework.web.socket.server.support; import java.net.InetSocketAddress; +import java.util.HashMap; import java.util.Map; import javax.websocket.Endpoint; +import javax.websocket.Extension; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.StandardWebSocketSession; @@ -62,4 +65,12 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException; + protected WebSocketExtension parseStandardExtension(Extension extension) { + Map params = new HashMap(extension.getParameters().size()); + for(Extension.Parameter param : extension.getParameters()) { + params.put(param.getName(),param.getValue()); + } + return new WebSocketExtension(extension.getName(),params); + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java index 14f6d1ce6b1..3064adb417e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import javax.servlet.http.HttpServletRequest; @@ -34,6 +36,7 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession; @@ -54,6 +57,8 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { private WebSocketServerFactory factory; + private List availableExtensions; + /** * Default constructor that creates {@link WebSocketServerFactory} through its default @@ -92,6 +97,17 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { return new String[] { String.valueOf(HandshakeRFC6455.VERSION) }; } + @Override + public List getAvailableExtensions(ServerHttpRequest request) { + if(this.availableExtensions == null) { + this.availableExtensions = new ArrayList(); + for(String extensionName : this.factory.getExtensionFactory().getExtensionNames()) { + this.availableExtensions.add(new WebSocketExtension(extensionName)); + } + } + return this.availableExtensions; + } + @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol, WebSocketHandler wsHandler, Map attrs) throws HandshakeFailureException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java index 4a96dff264d..7a5608ae8b9 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java @@ -17,8 +17,10 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; import javax.servlet.ServletContext; @@ -26,6 +28,7 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.Endpoint; +import javax.websocket.Extension; import org.apache.tomcat.websocket.server.WsServerContainer; import org.springframework.http.server.ServerHttpRequest; @@ -33,6 +36,7 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration; import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean; @@ -50,12 +54,26 @@ import org.springframework.web.socket.server.endpoint.ServletServerContainerFact */ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { + private List availableExtensions; @Override public String[] getSupportedVersions() { return new String[] { "13" }; } + @Override + public List getAvailableExtensions(ServerHttpRequest request) { + + if(this.availableExtensions == null) { + this.availableExtensions = new ArrayList(); + HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + for(Extension extension : getContainer(servletRequest).getInstalledExtensions()) { + this.availableExtensions.add(parseStandardExtension(extension)); + } + } + return this.availableExtensions; + } + @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, Endpoint endpoint) throws HandshakeFailureException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index 5b76febfafd..474d94dcd2f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -30,6 +31,7 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; @@ -66,6 +68,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private String acceptedProtocol; + private List extensions; public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler, Map handshakeAttributes) { @@ -116,6 +119,9 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.remoteAddress = remoteAddress; } + @Override + public List getExtensions() { return this.extensions; } + /** * Unlike WebSocket where sub-protocol negotiation is part of the * initial handshake, in HTTP transports the same negotiation must @@ -152,6 +158,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.principal = request.getPrincipal(); this.localAddress = request.getLocalAddress(); this.remoteAddress = request.getRemoteAddress(); + this.extensions = WebSocketExtension.parseHeaders(response.getHeaders().getSecWebSocketExtensions()); try { delegateConnectionEstablished(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java index 8474f108c06..e66af40e7e0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import org.springframework.http.HttpHeaders; @@ -27,6 +28,7 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.NativeWebSocketSession; @@ -89,6 +91,12 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession return this.wsSession.getAcceptedProtocol(); } + @Override + public List getExtensions() { + checkDelegateSessionInitialized(); + return this.wsSession.getExtensions(); + } + private void checkDelegateSessionInitialized() { Assert.state(this.wsSession != null, "WebSocketSession not yet initialized"); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java new file mode 100644 index 00000000000..2a674af1736 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import org.hamcrest.Matchers; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +/** + * Test fixture for {@link WebSocketExtension} + * @author Brian Clozel + */ +public class WebSocketExtensionTests { + + @Test + public void parseHeaderSingle() { + List extensions = WebSocketExtension.parseHeader("x-test-extension ; foo=bar"); + assertThat(extensions, Matchers.hasSize(1)); + WebSocketExtension extension = extensions.get(0); + assertEquals("x-test-extension", extension.getName()); + assertEquals(1, extension.getParameters().size()); + assertEquals("bar", extension.getParameters().get("foo")); + } + + @Test + public void parseHeaderMultiple() { + List extensions = WebSocketExtension.parseHeader("x-foo-extension, x-bar-extension"); + assertThat(extensions, Matchers.hasSize(2)); + assertEquals("x-foo-extension", extensions.get(0).getName()); + assertEquals("x-bar-extension", extensions.get(1).getName()); + } + + @Test + public void parseHeaders() { + List extensions = new ArrayList(); + extensions.add("x-foo-extension, x-bar-extension"); + extensions.add("x-test-extension"); + List parsedExtensions = WebSocketExtension.parseHeaders(extensions); + assertThat(parsedExtensions, Matchers.hasSize(3)); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java index 1cee547aab2..9f145cdd1e5 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java @@ -26,6 +26,7 @@ import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; @@ -58,6 +59,8 @@ public class TestSockJsSession extends AbstractSockJsSession { private String subProtocol; + private List extensions = new ArrayList(); + public TestSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler wsHandler, Map attributes) { @@ -118,7 +121,7 @@ public class TestSockJsSession extends AbstractSockJsSession { } /** - * @param remoteAddress the remoteAddress to set + * @param localAddress the remoteAddress to set */ public void setLocalAddress(InetSocketAddress localAddress) { this.localAddress = localAddress; @@ -148,6 +151,18 @@ public class TestSockJsSession extends AbstractSockJsSession { this.subProtocol = protocol; } + /** + * @return the extensions + */ + @Override + public List getExtensions() { return this.extensions; } + + /** + * + * @param extensions the extensions to set + */ + public void setExtensions(List extensions) { this.extensions = extensions; } + public CloseStatus getCloseStatus() { return this.closeStatus; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java index 6f16888e688..17935ac62db 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java @@ -27,6 +27,7 @@ import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; @@ -51,6 +52,8 @@ public class TestWebSocketSession implements WebSocketSession { private String protocol; + private List extensions = new ArrayList(); + private boolean open; private final List> messages = new ArrayList<>(); @@ -149,7 +152,7 @@ public class TestWebSocketSession implements WebSocketSession { } /** - * @param remoteAddress the remoteAddress to set + * @param localAddress the remoteAddress to set */ public void setLocalAddress(InetSocketAddress localAddress) { this.localAddress = localAddress; @@ -184,6 +187,18 @@ public class TestWebSocketSession implements WebSocketSession { this.protocol = protocol; } + /** + * @return the extensions + */ + @Override + public List getExtensions() { return this.extensions; } + + /** + * + * @param extensions the extensions to set + */ + public void setExtensions(List extensions) { this.extensions = extensions; } + /** * @return the open */ From 81dda069af4289395b85a6649b272b39372eb968 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 28 Oct 2013 22:37:01 -0400 Subject: [PATCH 2/2] Update WebSocket extensions change - add WebSocketHttpHeaders - client-side support for WebSocket extensions - DefaultHandshakeHandler updates - replace use of ServletAttributes in JettyRequestUpgradeStratey - upgrade spring-web to jetty 9.0.5 --- build.gradle | 6 +- .../org/springframework/http/HttpHeaders.java | 122 +------ .../WebRequestDataBinderIntegrationTests.java | 15 +- .../web/socket/WebSocketExtension.java | 167 --------- .../web/socket/WebSocketSession.java | 1 + .../socket/adapter/JettyWebSocketSession.java | 9 +- .../adapter/StandardWebSocketSession.java | 23 +- .../client/AbstractWebSocketClient.java | 43 ++- .../web/socket/client/WebSocketClient.java | 4 +- .../client/WebSocketConnectionManager.java | 3 +- .../endpoint/StandardWebSocketClient.java | 20 +- .../client/jetty/JettyWebSocketClient.java | 9 +- .../server/DefaultHandshakeHandler.java | 71 ++-- .../web/socket/server/HandshakeHandler.java | 3 +- .../socket/server/RequestUpgradeStrategy.java | 25 +- ...stractGlassFishRequestUpgradeStrategy.java | 36 +- .../AbstractStandardUpgradeStrategy.java | 55 ++- .../support/JettyRequestUpgradeStrategy.java | 91 ++++- .../support/TomcatRequestUpgradeStrategy.java | 29 +- .../session/AbstractHttpSockJsSession.java | 9 +- .../session/WebSocketServerSockJsSession.java | 2 +- .../socket/support/WebSocketExtension.java | 243 +++++++++++++ .../socket/support/WebSocketHttpHeaders.java | 328 ++++++++++++++++++ .../AbstractWebSocketIntegrationTests.java | 2 +- .../web/socket/WebSocketExtensionTests.java | 20 +- .../web/socket/WebSocketIntegrationTests.java | 173 +++++++++ .../WebSocketConnectionManagerTests.java | 5 +- .../StandardWebSocketClientTests.java | 5 +- .../jetty/JettyWebSocketClientTests.java | 3 +- .../server/DefaultHandshakeHandlerTests.java | 17 +- .../config/WebSocketConfigurationTests.java | 2 +- .../transport/session/TestSockJsSession.java | 43 +-- .../socket/support/TestWebSocketSession.java | 70 +--- .../support/WebSocketHttpHeadersTests.java | 54 +++ 34 files changed, 1112 insertions(+), 596 deletions(-) delete mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketExtension.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketHttpHeaders.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/support/WebSocketHttpHeadersTests.java diff --git a/build.gradle b/build.gradle index 8d14df6d261..f64da0967c7 100644 --- a/build.gradle +++ b/build.gradle @@ -558,15 +558,17 @@ project("spring-web") { optional("org.codehaus.jackson:jackson-mapper-asl:1.9.12") optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") optional("taglibs:standard:1.1.2") - optional("org.eclipse.jetty:jetty-servlet:8.1.5.v20120716") { + optional("org.eclipse.jetty:jetty-servlet:9.0.5.v20130815") { exclude group: "org.eclipse.jetty.orbit", module: "javax.servlet" } - optional("org.eclipse.jetty:jetty-server:8.1.5.v20120716") { + optional("org.eclipse.jetty:jetty-server:9.0.5.v20130815") { exclude group: "org.eclipse.jetty.orbit", module: "javax.servlet" } optional("log4j:log4j:1.2.17") testCompile(project(":spring-context-support")) // for JafMediaTypeFactory testCompile("xmlunit:xmlunit:1.3") + testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") + testCompile("log4j:log4j:1.2.17") } // pick up ContextLoader.properties in src/main diff --git a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java index 136ff118686..f5391599d85 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java +++ b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java @@ -92,16 +92,6 @@ public class HttpHeaders implements MultiValueMap, Serializable private static final String ORIGIN = "Origin"; - private static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; - - private static final String SEC_WEBSOCKET_EXTENSIONS = "Sec-WebSocket-Extensions"; - - private static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; - - private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; - - private static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version"; - private static final String PRAGMA = "Pragma"; private static final String UPGARDE = "Upgrade"; @@ -452,7 +442,7 @@ public class HttpHeaders implements MultiValueMap, Serializable set(IF_NONE_MATCH, toCommaDelimitedString(ifNoneMatchList)); } - private String toCommaDelimitedString(List list) { + protected String toCommaDelimitedString(List list) { StringBuilder builder = new StringBuilder(); for (Iterator iterator = list.iterator(); iterator.hasNext();) { String ifNoneMatch = iterator.next(); @@ -472,7 +462,7 @@ public class HttpHeaders implements MultiValueMap, Serializable return getFirstValueAsList(IF_NONE_MATCH); } - private List getFirstValueAsList(String header) { + protected List getFirstValueAsList(String header) { List result = new ArrayList(); String value = getFirst(header); @@ -537,114 +527,6 @@ public class HttpHeaders implements MultiValueMap, Serializable return getFirst(ORIGIN); } - /** - * Sets the (new) value of the {@code Sec-WebSocket-Accept} header. - * @param secWebSocketAccept the value of the header - */ - public void setSecWebSocketAccept(String secWebSocketAccept) { - set(SEC_WEBSOCKET_ACCEPT, secWebSocketAccept); - } - - /** - * Returns the value of the {@code Sec-WebSocket-Accept} header. - * @return the value of the header - */ - public String getSecWebSocketAccept() { - return getFirst(SEC_WEBSOCKET_ACCEPT); - } - - /** - * Returns the value of the {@code Sec-WebSocket-Extensions} header. - * @return the value of the header - */ - public List getSecWebSocketExtensions() { - List values = get(SEC_WEBSOCKET_EXTENSIONS); - if (CollectionUtils.isEmpty(values)) { - return Collections.emptyList(); - } - else if (values.size() == 1) { - return getFirstValueAsList(SEC_WEBSOCKET_EXTENSIONS); - } - else { - return values; - } - } - - /** - * Sets the (new) value of the {@code Sec-WebSocket-Extensions} header. - * @param secWebSocketExtensions the value of the header - */ - public void setSecWebSocketExtensions(List secWebSocketExtensions) { - set(SEC_WEBSOCKET_EXTENSIONS, toCommaDelimitedString(secWebSocketExtensions)); - } - - /** - * Sets the (new) value of the {@code Sec-WebSocket-Key} header. - * @param secWebSocketKey the value of the header - */ - public void setSecWebSocketKey(String secWebSocketKey) { - set(SEC_WEBSOCKET_KEY, secWebSocketKey); - } - - /** - * Returns the value of the {@code Sec-WebSocket-Key} header. - * @return the value of the header - */ - public String getSecWebSocketKey() { - return getFirst(SEC_WEBSOCKET_KEY); - } - - /** - * Sets the (new) value of the {@code Sec-WebSocket-Protocol} header. - * @param secWebSocketProtocol the value of the header - */ - public void setSecWebSocketProtocol(String secWebSocketProtocol) { - if (secWebSocketProtocol != null) { - set(SEC_WEBSOCKET_PROTOCOL, secWebSocketProtocol); - } - } - - /** - * Sets the (new) value of the {@code Sec-WebSocket-Protocol} header. - * @param secWebSocketProtocols the value of the header - */ - public void setSecWebSocketProtocol(List secWebSocketProtocols) { - set(SEC_WEBSOCKET_PROTOCOL, toCommaDelimitedString(secWebSocketProtocols)); - } - - /** - * Returns the value of the {@code Sec-WebSocket-Key} header. - * @return the value of the header - */ - public List getSecWebSocketProtocol() { - List values = get(SEC_WEBSOCKET_PROTOCOL); - if (CollectionUtils.isEmpty(values)) { - return Collections.emptyList(); - } - else if (values.size() == 1) { - return getFirstValueAsList(SEC_WEBSOCKET_PROTOCOL); - } - else { - return values; - } - } - - /** - * Sets the (new) value of the {@code Sec-WebSocket-Version} header. - * @param secWebSocketKey the value of the header - */ - public void setSecWebSocketVersion(String secWebSocketVersion) { - set(SEC_WEBSOCKET_VERSION, secWebSocketVersion); - } - - /** - * Returns the value of the {@code Sec-WebSocket-Version} header. - * @return the value of the header - */ - public String getSecWebSocketVersion() { - return getFirst(SEC_WEBSOCKET_VERSION); - } - /** * Sets the (new) value of the {@code Pragma} header. * @param pragma the value of the header diff --git a/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java index 66e91dda57b..e859bfca3fd 100644 --- a/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java @@ -19,7 +19,9 @@ package org.springframework.web.bind.support; import java.io.IOException; import java.util.List; +import javax.servlet.MultipartConfigElement; import javax.servlet.ServletException; +import javax.servlet.annotation.MultipartConfig; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -79,8 +81,15 @@ public class WebRequestDataBinderIntegrationTests { partsServlet = new PartsServlet(); partListServlet = new PartListServlet(); - handler.addServlet(new ServletHolder(partsServlet), "/parts"); - handler.addServlet(new ServletHolder(partListServlet), "/partlist"); + MultipartConfigElement multipartConfig = new MultipartConfigElement(""); + + ServletHolder holder = new ServletHolder(partsServlet); + holder.getRegistration().setMultipartConfig(multipartConfig); + handler.addServlet(holder, "/parts"); + + holder = new ServletHolder(partListServlet); + holder.getRegistration().setMultipartConfig(multipartConfig); + handler.addServlet(holder, "/partlist"); jettyServer.setHandler(handler); jettyServer.start(); } @@ -178,7 +187,6 @@ public class WebRequestDataBinderIntegrationTests { @SuppressWarnings("serial") private static class PartsServlet extends AbstractStandardMultipartServlet { - } private static class PartListBean { @@ -197,7 +205,6 @@ public class WebRequestDataBinderIntegrationTests { @SuppressWarnings("serial") private static class PartListServlet extends AbstractStandardMultipartServlet { - } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java deleted file mode 100644 index 399a3eca810..00000000000 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.socket; - -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.LinkedCaseInsensitiveMap; -import org.springframework.util.StringUtils; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; - -/** - * WebSocket Protocol extension. - * Adds new protocol features to the WebSocket protocol; the extensions used - * within a session are negotiated during the handshake phase: - *

    - *
  • the client may ask for specific extensions in the HTTP request
  • - *
  • the server declares the final list of supported extensions for the current session in the HTTP response
  • - *
- * - *

WebSocket Extension HTTP headers may include parameters and follow - * RFC 2616 Section 4.2 - * specifications.

- * - *

Note that the order of extensions in HTTP headers defines their order of execution, - * e.g. extensions "foo, bar" will be executed as "bar(foo(message))".

- * - * @author Brian Clozel - * @since 4.0 - * @see - * WebSocket Protocol Extensions, RFC 6455 - Section 9 - */ -public class WebSocketExtension { - - private final String name; - - private final Map parameters; - - public WebSocketExtension(String name) { - this(name,null); - } - - public WebSocketExtension(String name, Map parameters) { - Assert.hasLength(name, "extension name must not be empty"); - this.name = name; - if (!CollectionUtils.isEmpty(parameters)) { - Map m = new LinkedCaseInsensitiveMap(parameters.size(), Locale.ENGLISH); - m.putAll(parameters); - this.parameters = Collections.unmodifiableMap(m); - } - else { - this.parameters = Collections.emptyMap(); - } - } - - /** - * @return the name of the extension - */ - public String getName() { - return this.name; - } - - /** - * @return the parameters of the extension - */ - public Map getParameters() { - return this.parameters; - } - - /** - * Parse a list of raw WebSocket extension headers - */ - public static List parseHeaders(List headers) { - if (headers == null || headers.isEmpty()) { - return Collections.emptyList(); - } - else { - List result = new ArrayList(headers.size()); - for (String header : headers) { - result.addAll(parseHeader(header)); - } - return result; - } - } - - /** - * Parse a raw WebSocket extension header - */ - public static List parseHeader(String header) { - if (header == null || !StringUtils.hasText(header)) { - return Collections.emptyList(); - } - else { - List result = new ArrayList(); - for(String token : header.split(",")) { - result.add(parse(token)); - } - return result; - } - } - - private static WebSocketExtension parse(String extension) { - Assert.doesNotContain(extension,",","this string contains multiple extension declarations"); - String[] parts = StringUtils.tokenizeToStringArray(extension, ";"); - String name = parts[0].trim(); - - Map parameters = null; - if (parts.length > 1) { - parameters = new LinkedHashMap(parts.length - 1); - for (int i = 1; i < parts.length; i++) { - String parameter = parts[i]; - int eqIndex = parameter.indexOf('='); - if (eqIndex != -1) { - String attribute = parameter.substring(0, eqIndex); - String value = parameter.substring(eqIndex + 1, parameter.length()); - parameters.put(attribute, value); - } - } - } - - return new WebSocketExtension(name,parameters); - } - - /** - * Convert a list of WebSocketExtensions to a list of String, - * which is convenient for native HTTP headers. - */ - public static List toStringList(List extensions) { - List result = new ArrayList(extensions.size()); - for(WebSocketExtension extension : extensions) { - result.add(extension.toString()); - } - return result; - } - - @Override - public String toString() { - StringBuilder str = new StringBuilder(); - str.append(this.name); - for (String param : parameters.keySet()) { - str.append(';'); - str.append(param); - str.append('='); - str.append(this.parameters.get(param)); - } - return str.toString(); - } -} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java index a4634ddc398..7d5d1a5e4de 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import org.springframework.http.HttpHeaders; +import org.springframework.web.socket.support.WebSocketExtension; /** * A WebSocket session abstraction. Allows sending messages over a WebSocket connection diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java index 62222e394ee..5e507f6d065 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java @@ -32,7 +32,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; -import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketSession; /** @@ -114,9 +114,10 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion getExtensions() { checkNativeSessionInitialized(); if(this.extensions == null) { - this.extensions = new ArrayList(); - for(ExtensionConfig ext : getNativeSession().getUpgradeResponse().getExtensions()) { - this.extensions.add(new WebSocketExtension(ext.getName(),ext.getParameters())); + List source = getNativeSession().getUpgradeResponse().getExtensions(); + this.extensions = new ArrayList(source.size()); + for(ExtensionConfig e : source) { + this.extensions.add(new WebSocketExtension(e.getName(), e.getParameters())); } } return this.extensions; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java index 6a9fa9be8c5..d01d33b9fd0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java @@ -36,8 +36,9 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; -import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.support.WebSocketHttpHeaders; /** * A {@link WebSocketSession} for use with the standard WebSocket for Java API. @@ -59,18 +60,18 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion handshakeAttributes, + public StandardWebSocketSession(HttpHeaders headers, Map handshakeAttributes, InetSocketAddress localAddress, InetSocketAddress remoteAddress) { super(handshakeAttributes); - handshakeHeaders = (handshakeHeaders != null) ? handshakeHeaders : new HttpHeaders(); - this.handshakeHeaders = HttpHeaders.readOnlyHttpHeaders(handshakeHeaders); + headers = (headers != null) ? headers : new HttpHeaders(); + this.handshakeHeaders = HttpHeaders.readOnlyHttpHeaders(headers); this.localAddress = localAddress; this.remoteAddress = remoteAddress; } @@ -119,14 +120,10 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion getExtensions() { checkNativeSessionInitialized(); if(this.extensions == null) { - List nativeExtensions = getNativeSession().getNegotiatedExtensions(); - this.extensions = new ArrayList(nativeExtensions.size()); - for(Extension nativeExtension : nativeExtensions) { - Map parameters = new HashMap(); - for (Extension.Parameter param : nativeExtension.getParameters()) { - parameters.put(param.getName(),param.getValue()); - } - this.extensions.add(new WebSocketExtension(nativeExtension.getName(),parameters)); + List source = getNativeSession().getNegotiatedExtensions(); + this.extensions = new ArrayList(source.size()); + for(Extension e : source) { + this.extensions.add(new WebSocketExtension.StandardToWebSocketExtensionAdapter(e)); } } return this.extensions; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java index ca8faaa7881..5c01e61cf07 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java @@ -31,6 +31,8 @@ import org.springframework.util.Assert; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.support.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketHttpHeaders; import org.springframework.web.util.UriComponentsBuilder; @@ -44,19 +46,19 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { protected final Log logger = LogFactory.getLog(getClass()); - private static final Set disallowedHeaders = new HashSet(); + private static final Set specialHeaders = new HashSet(); static { - disallowedHeaders.add("cache-control"); - disallowedHeaders.add("cookie"); - disallowedHeaders.add("connection"); - disallowedHeaders.add("host"); - disallowedHeaders.add("sec-websocket-extensions"); - disallowedHeaders.add("sec-websocket-key"); - disallowedHeaders.add("sec-websocket-protocol"); - disallowedHeaders.add("sec-websocket-version"); - disallowedHeaders.add("pragma"); - disallowedHeaders.add("upgrade"); + specialHeaders.add("cache-control"); + specialHeaders.add("cookie"); + specialHeaders.add("connection"); + specialHeaders.add("host"); + specialHeaders.add("sec-websocket-extensions"); + specialHeaders.add("sec-websocket-key"); + specialHeaders.add("sec-websocket-protocol"); + specialHeaders.add("sec-websocket-version"); + specialHeaders.add("pragma"); + specialHeaders.add("upgrade"); } @@ -71,7 +73,7 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { @Override public final ListenableFuture doHandshake(WebSocketHandler webSocketHandler, - HttpHeaders headers, URI uri) { + WebSocketHttpHeaders headers, URI uri) { Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); Assert.notNull(uri, "uri must not be null"); @@ -86,18 +88,19 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { HttpHeaders headersToUse = new HttpHeaders(); if (headers != null) { for (String header : headers.keySet()) { - if (!disallowedHeaders.contains(header.toLowerCase())) { + if (!specialHeaders.contains(header.toLowerCase())) { headersToUse.put(header, headers.get(header)); } } } - List subProtocols = new ArrayList(); - if ((headers != null) && (headers.getSecWebSocketProtocol() != null)) { - subProtocols.addAll(headers.getSecWebSocketProtocol()); - } + List subProtocols = ((headers != null) && (headers.getSecWebSocketProtocol() != null)) ? + headers.getSecWebSocketProtocol() : Collections.emptyList(); + + List extensions = ((headers != null) && (headers.getSecWebSocketExtensions() != null)) ? + headers.getSecWebSocketExtensions() : Collections.emptyList(); - return doHandshakeInternal(webSocketHandler, headersToUse, uri, subProtocols, + return doHandshakeInternal(webSocketHandler, headersToUse, uri, subProtocols, extensions, Collections.emptyMap()); } @@ -109,12 +112,14 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { * headers filtered out, never {@code null} * @param uri the target URI for the handshake, never {@code null} * @param subProtocols requested sub-protocols, or an empty list + * @param extensions requested WebSocket extensions, or an empty list * @param handshakeAttributes attributes to make available via * {@link WebSocketSession#getHandshakeAttributes()}; currently always an empty map. * * @return the established WebSocket session wrapped in a ListenableFuture. */ protected abstract ListenableFuture doHandshakeInternal(WebSocketHandler webSocketHandler, - HttpHeaders headers, URI uri, List subProtocols, Map handshakeAttributes); + HttpHeaders headers, URI uri, List subProtocols, List extensions, + Map handshakeAttributes); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketClient.java index 8711be594d9..ad6dfa6ce65 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketClient.java @@ -22,6 +22,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.support.WebSocketHttpHeaders; /** * Contract for initiating a WebSocket request. As an alternative considering using the @@ -38,6 +39,7 @@ public interface WebSocketClient { ListenableFuture doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVariables); - ListenableFuture doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri); + ListenableFuture doHandshake(WebSocketHandler webSocketHandler, + WebSocketHttpHeaders headers, URI uri); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java index c7d1aded1a4..7a115312e26 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java @@ -25,6 +25,7 @@ import org.springframework.util.concurrent.ListenableFutureCallback; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator; +import org.springframework.web.socket.support.WebSocketHttpHeaders; /** * A WebSocket connection manager that is given a URI, a {@link WebSocketClient}, and a @@ -43,7 +44,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { private WebSocketSession webSocketSession; - private HttpHeaders headers = new HttpHeaders(); + private WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); private final boolean syncClientLifecycle; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java index bcbebb1c007..55421c5866b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java @@ -20,17 +20,14 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.URI; import java.net.UnknownHostException; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.concurrent.Callable; -import javax.websocket.ClientEndpointConfig; +import javax.websocket.*; import javax.websocket.ClientEndpointConfig.Configurator; -import javax.websocket.ContainerProvider; -import javax.websocket.Endpoint; -import javax.websocket.HandshakeResponse; -import javax.websocket.WebSocketContainer; import org.springframework.core.task.AsyncListenableTaskExecutor; import org.springframework.core.task.SimpleAsyncTaskExecutor; @@ -43,6 +40,7 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.StandardWebSocketSession; import org.springframework.web.socket.client.AbstractWebSocketClient; +import org.springframework.web.socket.support.WebSocketExtension; /** * Initiates WebSocket requests to a WebSocket server programatically through the standard @@ -96,7 +94,8 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { @Override protected ListenableFuture doHandshakeInternal(WebSocketHandler webSocketHandler, - HttpHeaders headers, final URI uri, List protocols, Map handshakeAttributes) { + HttpHeaders headers, final URI uri, List protocols, + List extensions, Map handshakeAttributes) { int port = getPort(uri); InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port); @@ -108,6 +107,7 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { final ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create(); configBuidler.configurator(new StandardWebSocketClientConfigurator(headers)); configBuidler.preferredSubprotocols(protocols); + configBuidler.extensions(adaptExtensions(extensions)); final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session); return this.taskExecutor.submitListenable(new Callable() { @@ -119,6 +119,14 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { }); } + private static List adaptExtensions(List extensions) { + List result = new ArrayList(); + for (WebSocketExtension e : extensions) { + result.add(new WebSocketExtension.WebSocketToStandardExtensionAdapter(e)); + } + return result; + } + private InetAddress getLocalHost() { try { return InetAddress.getLocalHost(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java index 5ed8e8914d5..d76ec05d0a7 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java @@ -38,6 +38,7 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession; import org.springframework.web.socket.client.AbstractWebSocketClient; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; @@ -169,10 +170,16 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma @Override public ListenableFuture doHandshakeInternal(WebSocketHandler wsHandler, - HttpHeaders headers, final URI uri, List protocols, Map handshakeAttributes) { + HttpHeaders headers, final URI uri, List protocols, + List extensions, Map handshakeAttributes) { final ClientUpgradeRequest request = new ClientUpgradeRequest(); request.setSubProtocols(protocols); + + for (WebSocketExtension e : extensions) { + request.addExtensions(new WebSocketExtension.WebSocketToJettyExtensionConfigAdapter(e)); + } + for (String header : headers.keySet()) { request.setHeader(header, headers.get(header)); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java index f4b02b3a01d..a5a8338b1a5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java @@ -32,8 +32,9 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; -import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.support.WebSocketHttpHeaders; /** * A default {@link HandshakeHandler} implementation. Performs initial validation of the @@ -144,8 +145,10 @@ public class DefaultHandshakeHandler implements HandshakeHandler { public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHeaders()); + if (logger.isDebugEnabled()) { - logger.debug("Initiating handshake for " + request.getURI() + ", headers=" + request.getHeaders()); + logger.debug("Initiating handshake for " + request.getURI() + ", headers=" + headers); } try { @@ -155,16 +158,16 @@ public class DefaultHandshakeHandler implements HandshakeHandler { logger.debug("Only HTTP GET is allowed, current method is " + request.getMethod()); return false; } - if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) { + if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { handleInvalidUpgradeHeader(request, response); return false; } - if (!request.getHeaders().getConnection().contains("Upgrade") && - !request.getHeaders().getConnection().contains("upgrade")) { + if (!headers.getConnection().contains("Upgrade") && + !headers.getConnection().contains("upgrade")) { handleInvalidConnectHeader(request, response); return false; } - if (!isWebSocketVersionSupported(request)) { + if (!isWebSocketVersionSupported(headers)) { handleWebSocketVersionNotSupported(request, response); return false; } @@ -172,7 +175,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { response.setStatusCode(HttpStatus.FORBIDDEN); return false; } - String wsKey = request.getHeaders().getSecWebSocketKey(); + String wsKey = headers.getSecWebSocketKey(); if (wsKey == null) { logger.debug("Missing \"Sec-WebSocket-Key\" header"); response.setStatusCode(HttpStatus.BAD_REQUEST); @@ -184,20 +187,17 @@ public class DefaultHandshakeHandler implements HandshakeHandler { "Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex); } - String subProtocol = selectProtocol(request.getHeaders().getSecWebSocketProtocol()); + String subProtocol = selectProtocol(headers.getSecWebSocketProtocol()); + + List requested = headers.getSecWebSocketExtensions(); + List supported = this.requestUpgradeStrategy.getSupportedExtensions(request); + List extensions = filterRequestedExtensions(request, requested, supported); if (logger.isDebugEnabled()) { - logger.debug("Upgrading request, sub-protocol=" + subProtocol); + logger.debug("Upgrading request, sub-protocol=" + subProtocol + ", extensions=" + extensions); } - List requestedExtensions = WebSocketExtension - .parseHeaders(request.getHeaders().getSecWebSocketExtensions()); - - List filteredExtensions = filterRequestedExtensions(requestedExtensions, - this.requestUpgradeStrategy.getAvailableExtensions(request)); - request.getHeaders().setSecWebSocketExtensions(WebSocketExtension.toStringList(filteredExtensions)); - - this.requestUpgradeStrategy.upgrade(request, response, subProtocol, wsHandler, attributes); + this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, wsHandler, attributes); return true; } @@ -214,8 +214,8 @@ public class DefaultHandshakeHandler implements HandshakeHandler { response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes("UTF-8")); } - protected boolean isWebSocketVersionSupported(ServerHttpRequest request) { - String version = request.getHeaders().getSecWebSocketVersion(); + protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders httpHeaders) { + String version = httpHeaders.getSecWebSocketVersion(); String[] supportedVersions = getSupportedVerions(); if (logger.isDebugEnabled()) { logger.debug("Requested version=" + version + ", supported=" + Arrays.toString(supportedVersions)); @@ -238,7 +238,8 @@ public class DefaultHandshakeHandler implements HandshakeHandler { protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) { logger.debug("WebSocket version not supported " + request.getHeaders().get("Sec-WebSocket-Version")); response.setStatusCode(HttpStatus.UPGRADE_REQUIRED); - response.getHeaders().setSecWebSocketVersion(StringUtils.arrayToCommaDelimitedString(getSupportedVerions())); + response.getHeaders().put(WebSocketHttpHeaders.SEC_WEBSOCKET_VERSION, Arrays.asList( + StringUtils.arrayToCommaDelimitedString(getSupportedVerions()))); } protected boolean isValidOrigin(ServerHttpRequest request) { @@ -264,26 +265,26 @@ public class DefaultHandshakeHandler implements HandshakeHandler { } /** - * Filter the list of WebSocket Extensions requested by the client. - * Since the negotiation process happens during the upgrade phase within the server - * implementation, one can customize the applied extensions only by filtering the - * requested extensions by the client. + * Filter the list of requested WebSocket extensions. + *

+ * By default all request extensions are returned. The WebSocket server will further + * compare the requested extensions against the list of supported extensions and + * return only the ones that are both requested and supported. + * + * @param request the current request + * @param requested the list of extensions requested by the client + * @param supported the list of extensions supported by the server * - *

The default implementation of this method doesn't filter any of the extensions - * requested by the client. - * @param requestedExtensions the list of extensions requested by the client - * @param supportedExtensions the list of extensions supported by the server - * @return the filtered list of requested extensions + * @return the selected extensions or an empty list */ - protected List filterRequestedExtensions(List requestedExtensions, - List supportedExtensions) { + protected List filterRequestedExtensions(ServerHttpRequest request, + List requested, List supported) { - if (requestedExtensions != null) { + if (requested != null) { if (logger.isDebugEnabled()) { - logger.debug("Requested extension(s): " + requestedExtensions - + ", supported extension(s): " + supportedExtensions); + logger.debug("Requested extension(s): " + requested + ", supported extension(s): " + supported); } } - return requestedExtensions; + return requested; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java index f2576f6559d..be0ad444e9a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java @@ -45,7 +45,8 @@ public interface HandshakeHandler { * {@link PerConnectionWebSocketHandler} for providing a handler with * per-connection lifecycle. * @param attributes handshake request specific attributes to be set on the WebSocket - * session and thus made available to the {@link WebSocketHandler} + * session via {@link HandshakeInterceptor} and thus made available to the + * {@link WebSocketHandler}; * * @return whether the handshake negotiation was successful or not. In either case the * response status, headers, and body will have been updated to reflect the diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java index 89b197fe38c..f3e950f2212 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java @@ -21,7 +21,7 @@ import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; -import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; /** @@ -39,27 +39,32 @@ public interface RequestUpgradeStrategy { String[] getSupportedVersions(); /** - * @return the list of available WebSocket protocol extensions, - * implemented by the underlying WebSocket server. + * @return the WebSocket protocol extensions supported by the underlying WebSocket server. */ - List getAvailableExtensions(ServerHttpRequest request); + List getSupportedExtensions(ServerHttpRequest request); /** * Perform runtime specific steps to complete the upgrade. Invoked after successful * negotiation of the handshake request. * + * * @param request the current request * @param response the current response - * @param acceptedProtocol the accepted sub-protocol, if any + * @param selectedProtocol the selected sub-protocol, if any + * @param selectedExtensions the selected WebSocket protocol extensions * @param wsHandler the handler for WebSocket messages - * @param attributes handshake context attributes + * @param attributes handshake request specific attributes to be set on the WebSocket + * session via {@link org.springframework.web.socket.server.HandshakeInterceptor} + * and thus made available to the + * {@link org.springframework.web.socket.WebSocketHandler}; * * @throws HandshakeFailureException thrown when handshake processing failed to - * complete due to an internal, unrecoverable error, i.e. a server error as - * opposed to a failure to successfully negotiate the requirements of the - * handshake request. + * complete due to an internal, unrecoverable error, i.e. a server error as + * opposed to a failure to successfully negotiate the requirements of the + * handshake request. */ - void upgrade(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, + void upgrade(ServerHttpRequest request, ServerHttpResponse response, + String selectedProtocol, List selectedExtensions, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java index 7349615f500..392eb2a2dd5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java @@ -19,21 +19,17 @@ package org.springframework.web.socket.server.support; import java.io.IOException; import java.lang.reflect.Constructor; import java.net.URI; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Random; -import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.DeploymentException; import javax.websocket.Endpoint; import javax.websocket.Extension; -import javax.websocket.server.ServerContainer; -import org.apache.tomcat.websocket.server.WsServerContainer; import org.glassfish.tyrus.core.ComponentProviderService; import org.glassfish.tyrus.core.EndpointWrapper; import org.glassfish.tyrus.core.ErrorCollector; @@ -53,7 +49,6 @@ import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; -import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration; import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean; @@ -73,29 +68,16 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt private final static Random random = new Random(); - private List availableExtensions; @Override public String[] getSupportedVersions() { return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions()); } - @Override - public List getAvailableExtensions(ServerHttpRequest request) { - - if(this.availableExtensions == null) { - this.availableExtensions = new ArrayList(); - HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); - for(Extension extension : getContainer(servletRequest).getInstalledExtensions()) { - this.availableExtensions.add(parseStandardExtension(extension)); - } - } - return this.availableExtensions; - } - @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, - String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException { + String selectedProtocol, List selectedExtensions, + Endpoint endpoint) throws HandshakeFailureException { Assert.isTrue(request instanceof ServletServerHttpRequest); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -103,7 +85,9 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt Assert.isTrue(response instanceof ServletServerHttpResponse); HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse(); - WebSocketApplication webSocketApplication = createTyrusEndpoint(servletRequest, endpoint, selectedProtocol); + WebSocketApplication webSocketApplication = createTyrusEndpoint(servletRequest, + endpoint, selectedProtocol, selectedExtensions); + WebSocketEngine webSocketEngine = WebSocketEngine.getEngine(); try { @@ -125,13 +109,6 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt } } - public ServerContainer getContainer(HttpServletRequest servletRequest) { - - String attributeName = "javax.websocket.server.ServerContainer"; - ServletContext servletContext = servletRequest.getServletContext(); - return (ServerContainer)servletContext.getAttribute(attributeName); - } - private boolean performUpgrade(HttpServletRequest request, HttpServletResponse response, HttpHeaders headers, WebSocketApplication wsApp) throws IOException { @@ -162,12 +139,13 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt } private WebSocketApplication createTyrusEndpoint(HttpServletRequest request, - Endpoint endpoint, String selectedProtocol) { + Endpoint endpoint, String selectedProtocol, List selectedExtensions) { // shouldn't matter for processing but must be unique String endpointPath = "/" + random.nextLong(); ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint); endpointConfig.setSubprotocols(Arrays.asList(selectedProtocol)); + endpointConfig.setExtensions(selectedExtensions); return createTyrusEndpoint(new EndpointWrapper(endpoint, endpointConfig, ComponentProviderService.create(), null, "/", new ErrorCollector(), endpointConfig.getConfigurator())); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java index 0f62add3587..2bbcec5f4a8 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java @@ -17,18 +17,22 @@ package org.springframework.web.socket.server.support; import java.net.InetSocketAddress; -import java.util.HashMap; -import java.util.Map; +import java.util.*; +import javax.servlet.ServletContext; +import javax.servlet.http.HttpServletRequest; import javax.websocket.Endpoint; import javax.websocket.Extension; +import javax.websocket.WebSocketContainer; +import javax.websocket.server.ServerContainer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; -import org.springframework.web.socket.WebSocketExtension; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.StandardWebSocketSession; @@ -46,9 +50,34 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS protected final Log logger = LogFactory.getLog(getClass()); + private volatile List extensions; + @Override - public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, + public List getSupportedExtensions(ServerHttpRequest request) { + if(this.extensions == null) { + HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + this.extensions = getInstalledExtensions(getContainer(servletRequest)); + } + return this.extensions; + } + + protected ServerContainer getContainer(HttpServletRequest request) { + ServletContext servletContext = request.getServletContext(); + return (ServerContainer) servletContext.getAttribute("javax.websocket.server.ServerContainer"); + } + + protected List getInstalledExtensions(WebSocketContainer container) { + List result = new ArrayList(); + for(Extension e : container.getInstalledExtensions()) { + result.add(new WebSocketExtension.StandardToWebSocketExtensionAdapter(e)); + } + return result; + } + + @Override + public void upgrade(ServerHttpRequest request, ServerHttpResponse response, + String selectedProtocol, List selectedExtensions, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { HttpHeaders headers = request.getHeaders(); @@ -59,18 +88,16 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS StandardWebSocketSession session = new StandardWebSocketSession(headers, attributes, localAddr, remoteAddr); StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, session); - upgradeInternal(request, response, acceptedProtocol, endpoint); + List extensions = new ArrayList(); + for (WebSocketExtension e : selectedExtensions) { + extensions.add(new WebSocketExtension.WebSocketToStandardExtensionAdapter(e)); + } + + upgradeInternal(request, response, selectedProtocol, extensions, endpoint); } protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, - String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException; - - protected WebSocketExtension parseStandardExtension(Extension extension) { - Map params = new HashMap(extension.getParameters().size()); - for(Extension.Parameter param : extension.getParameters()) { - params.put(param.getName(),param.getValue()); - } - return new WebSocketExtension(extension.getName(),params); - } + String selectedProtocol, List selectedExtensions, + Endpoint endpoint) throws HandshakeFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java index 3064adb417e..a2dfcdd537b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java @@ -27,16 +27,18 @@ import javax.servlet.http.HttpServletResponse; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeResponse; import org.eclipse.jetty.websocket.api.WebSocketPolicy; +import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; import org.eclipse.jetty.websocket.server.HandshakeRFC6455; import org.eclipse.jetty.websocket.server.WebSocketServerFactory; -import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import org.eclipse.jetty.websocket.servlet.WebSocketCreator; +import org.springframework.core.NamedThreadLocal; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; -import org.springframework.web.socket.WebSocketExtension; +import org.springframework.util.CollectionUtils; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession; @@ -53,11 +55,13 @@ import org.springframework.web.socket.server.RequestUpgradeStrategy; */ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { - private static final String WS_HANDLER_ATTR_NAME = JettyRequestUpgradeStrategy.class.getName() + ".WS_LISTENER"; + private static final ThreadLocal wsContainerHolder = + new NamedThreadLocal("WebSocket Handler Container"); + private WebSocketServerFactory factory; - private List availableExtensions; + private volatile List supportedExtensions; /** @@ -79,8 +83,14 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { this.factory.setCreator(new WebSocketCreator() { @Override public Object createWebSocket(UpgradeRequest request, UpgradeResponse response) { - Assert.isInstanceOf(ServletUpgradeRequest.class, request); - return ((ServletUpgradeRequest) request).getServletAttributes().get(WS_HANDLER_ATTR_NAME); + + WebSocketHandlerContainer container = wsContainerHolder.get(); + Assert.state(container != null, "Expected WebSocketHandlerContainer"); + + response.setAcceptedSubProtocol(container.getSelectedProtocol()); + response.setExtensions(container.getExtensionConfigs()); + + return container.getHandler(); } }); try { @@ -98,19 +108,25 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { } @Override - public List getAvailableExtensions(ServerHttpRequest request) { - if(this.availableExtensions == null) { - this.availableExtensions = new ArrayList(); - for(String extensionName : this.factory.getExtensionFactory().getExtensionNames()) { - this.availableExtensions.add(new WebSocketExtension(extensionName)); - } + public List getSupportedExtensions(ServerHttpRequest request) { + if (this.supportedExtensions == null) { + this.supportedExtensions = getWebSocketExtensions(); } - return this.availableExtensions; + return this.supportedExtensions; + } + + private List getWebSocketExtensions() { + List result = new ArrayList(); + for(String name : this.factory.getExtensionFactory().getExtensionNames()) { + result.add(new WebSocketExtension(name)); + } + return result; } @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String protocol, WebSocketHandler wsHandler, Map attrs) throws HandshakeFailureException { + String selectedProtocol, List selectedExtensions, + WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { Assert.isInstanceOf(ServletServerHttpRequest.class, request); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -120,11 +136,14 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { Assert.isTrue(this.factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake"); - JettyWebSocketSession session = new JettyWebSocketSession(request.getPrincipal(), attrs); + JettyWebSocketSession session = new JettyWebSocketSession(request.getPrincipal(), attributes); JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(wsHandler, session); + WebSocketHandlerContainer container = + new WebSocketHandlerContainer(handlerAdapter, selectedProtocol, selectedExtensions); + try { - servletRequest.setAttribute(WS_HANDLER_ATTR_NAME, handlerAdapter); + wsContainerHolder.set(container); this.factory.acceptWebSocket(servletRequest, servletResponse); } catch (IOException ex) { @@ -132,8 +151,46 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { "Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex); } finally { - servletRequest.removeAttribute(WS_HANDLER_ATTR_NAME); + wsContainerHolder.remove(); } } + + private static class WebSocketHandlerContainer { + + private final JettyWebSocketHandlerAdapter handler; + + private final String selectedProtocol; + + private final List extensionConfigs; + + private WebSocketHandlerContainer(JettyWebSocketHandlerAdapter handler, String protocol, + List extensions) { + + this.handler = handler; + this.selectedProtocol = protocol; + + if (CollectionUtils.isEmpty(extensions)) { + this.extensionConfigs = null; + } + else { + this.extensionConfigs = new ArrayList(); + for (WebSocketExtension e : extensions) { + this.extensionConfigs.add(new WebSocketExtension.WebSocketToJettyExtensionConfigAdapter(e)); + } + } + } + + private JettyWebSocketHandlerAdapter getHandler() { + return this.handler; + } + + private String getSelectedProtocol() { + return this.selectedProtocol; + } + + private List getExtensionConfigs() { + return this.extensionConfigs; + } + } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java index 7a5608ae8b9..661425ae3d3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java @@ -17,13 +17,11 @@ package org.springframework.web.socket.server.support; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; -import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -36,7 +34,6 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; -import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration; import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean; @@ -54,29 +51,16 @@ import org.springframework.web.socket.server.endpoint.ServletServerContainerFact */ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { - private List availableExtensions; @Override public String[] getSupportedVersions() { return new String[] { "13" }; } - @Override - public List getAvailableExtensions(ServerHttpRequest request) { - - if(this.availableExtensions == null) { - this.availableExtensions = new ArrayList(); - HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); - for(Extension extension : getContainer(servletRequest).getInstalledExtensions()) { - this.availableExtensions.add(parseStandardExtension(extension)); - } - } - return this.availableExtensions; - } - @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, - String acceptedProtocol, Endpoint endpoint) throws HandshakeFailureException { + String selectedProtocol, List selectedExtensions, + Endpoint endpoint) throws HandshakeFailureException { Assert.isTrue(request instanceof ServletServerHttpRequest); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -89,7 +73,8 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg Map pathParams = Collections. emptyMap(); ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(path, endpoint); - endpointConfig.setSubprotocols(Arrays.asList(acceptedProtocol)); + endpointConfig.setSubprotocols(Arrays.asList(selectedProtocol)); + endpointConfig.setExtensions(selectedExtensions); try { getContainer(servletRequest).doUpgrade(servletRequest, servletResponse, endpointConfig, pathParams); @@ -104,10 +89,8 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg } } - public WsServerContainer getContainer(HttpServletRequest servletRequest) { - String attribute = "javax.websocket.server.ServerContainer"; - ServletContext servletContext = servletRequest.getServletContext(); - return (WsServerContainer) servletContext.getAttribute(attribute); + public WsServerContainer getContainer(HttpServletRequest request) { + return (WsServerContainer) super.getContainer(request); } } \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index 474d94dcd2f..7373dc05228 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; @@ -31,7 +32,7 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; -import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; @@ -70,6 +71,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private List extensions; + public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler, Map handshakeAttributes) { @@ -120,7 +122,9 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { } @Override - public List getExtensions() { return this.extensions; } + public List getExtensions() { + return Collections.emptyList(); + } /** * Unlike WebSocket where sub-protocol negotiation is part of the @@ -158,7 +162,6 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.principal = request.getPrincipal(); this.localAddress = request.getLocalAddress(); this.remoteAddress = request.getRemoteAddress(); - this.extensions = WebSocketExtension.parseHeaders(response.getHeaders().getSecWebSocketExtensions()); try { delegateConnectionEstablished(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java index e66af40e7e0..b5137f91a47 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java @@ -28,7 +28,7 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; -import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.NativeWebSocketSession; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketExtension.java b/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketExtension.java new file mode 100644 index 00000000000..6a474313fe4 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketExtension.java @@ -0,0 +1,243 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.support; + +import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.StringUtils; + +import javax.websocket.Extension; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * Represents a WebSocket extension as defined in the RFC 6455. + * WebSocket extensions add protocol features to the WebSocket protocol. The extensions + * used within a session are negotiated during the handshake phase as follows: + *

    + *
  • the client may ask for specific extensions in the HTTP handshake request
  • + *
  • the server responds with the final list of extensions to use in the current session
  • + *
+ * + *

WebSocket Extension HTTP headers may include parameters and follow + * RFC 2616 Section 4.2

+ * + *

Note that the order of extensions in HTTP headers defines their order of execution, + * e.g. extensions "foo, bar" will be executed as "bar(foo(message))".

+ * + * @author Brian Clozel + * @since 4.0 + * @see + * WebSocket Protocol Extensions, RFC 6455 - Section 9 + */ +public class WebSocketExtension { + + private final String name; + + private final Map parameters; + + + /** + * Create a WebSocketExtension with the given name. + * + * @param name the name of the extension + */ + public WebSocketExtension(String name) { + this(name, null); + } + + /** + * Create a WebSocketExtension with the given name and parameters. + * + * @param name the name of the extension + * @param parameters the parameters + */ + public WebSocketExtension(String name, Map parameters) { + Assert.hasLength(name, "extension name must not be empty"); + this.name = name; + if (!CollectionUtils.isEmpty(parameters)) { + Map m = new LinkedCaseInsensitiveMap(parameters.size(), Locale.ENGLISH); + m.putAll(parameters); + this.parameters = Collections.unmodifiableMap(m); + } + else { + this.parameters = Collections.emptyMap(); + } + } + + /** + * @return the name of the extension + */ + public String getName() { + return this.name; + } + + /** + * @return the parameters of the extension, never {@code null} + */ + public Map getParameters() { + return this.parameters; + } + + /** + * Parse the given, comma-separated string into a list of {@code WebSocketExtension} objects. + *

This method can be used to parse a "Sec-WebSocket-Extension" extensions. + * @param extensions the string to parse + * @return the list of extensions + * @throws IllegalArgumentException if the string cannot be parsed + */ + public static List parseExtensions(String extensions) { + if (extensions == null || !StringUtils.hasText(extensions)) { + return Collections.emptyList(); + } + else { + List result = new ArrayList(); + for(String token : extensions.split(",")) { + result.add(parseExtension(token)); + } + return result; + } + } + + private static WebSocketExtension parseExtension(String extension) { + Assert.doesNotContain(extension, ",", "Expected a single extension value: " + extension); + String[] parts = StringUtils.tokenizeToStringArray(extension, ";"); + String name = parts[0].trim(); + + Map parameters = null; + if (parts.length > 1) { + parameters = new LinkedHashMap(parts.length - 1); + for (int i = 1; i < parts.length; i++) { + String parameter = parts[i]; + int eqIndex = parameter.indexOf('='); + if (eqIndex != -1) { + String attribute = parameter.substring(0, eqIndex); + String value = parameter.substring(eqIndex + 1, parameter.length()); + parameters.put(attribute, value); + } + } + } + + return new WebSocketExtension(name, parameters); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if ((o == null) || (getClass() != o.getClass())) { + return false; + } + WebSocketExtension that = (WebSocketExtension) o; + if (!name.equals(that.name)) { + return false; + } + if (!parameters.equals(that.parameters)) { + return false; + } + return true; + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + parameters.hashCode(); + return result; + } + + @Override + public String toString() { + StringBuilder str = new StringBuilder(); + str.append(this.name); + for (String param : parameters.keySet()) { + str.append(';'); + str.append(param); + str.append('='); + str.append(this.parameters.get(param)); + } + return str.toString(); + } + + + // Standard WebSocketExtension adapters + + public static class StandardToWebSocketExtensionAdapter extends WebSocketExtension { + + public StandardToWebSocketExtensionAdapter(Extension ext) { + super(ext.getName()); + for (Extension.Parameter p : ext.getParameters()) { + super.getParameters().put(p.getName(), p.getValue()); + } + } + } + + public static class WebSocketToStandardExtensionAdapter implements Extension { + + private final String name; + + private final List parameters = new ArrayList(); + + public WebSocketToStandardExtensionAdapter(final WebSocketExtension ext) { + this.name = ext.getName(); + List params = new ArrayList(); + for (final String paramName : ext.getParameters().keySet()) { + this.parameters.add(new Parameter() { + @Override + public String getName() { + return paramName; + } + @Override + public String getValue() { + return ext.getParameters().get(paramName); + } + }); + } + } + + @Override + public String getName() { + return name; + } + + @Override + public List getParameters() { + return this.parameters; + } + } + + // Jetty WebSocketExtension adapters + + public static class WebSocketToJettyExtensionConfigAdapter extends ExtensionConfig { + + public WebSocketToJettyExtensionConfigAdapter(WebSocketExtension extension) { + super(extension.getName()); + for (Map.Entry p : extension.getParameters().entrySet()) { + super.setParameter(p.getKey(), p.getValue()); + } + } + } + + + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketHttpHeaders.java b/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketHttpHeaders.java new file mode 100644 index 00000000000..c6a8e5a11e1 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/support/WebSocketHttpHeaders.java @@ -0,0 +1,328 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.support; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.CollectionUtils; + +import java.io.Serializable; +import java.util.*; + +/** + * An {@link org.springframework.http.HttpHeaders} variant that adds support for + * the HTTP headers defined by the WebSocket specification RFC 6455. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class WebSocketHttpHeaders extends HttpHeaders { + + public static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; + + public static final String SEC_WEBSOCKET_EXTENSIONS = "Sec-WebSocket-Extensions"; + + public static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; + + public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; + + public static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version"; + + private final HttpHeaders headers; + + + + + /** + * Create a new instance. + */ + public WebSocketHttpHeaders() { + this(new HttpHeaders(), false); + } + + /** + * Create an instance that wraps the given pre-existing HttpHeaders and also + * propagate all changes to it. + * + * @param headers the HTTP headers to wrap + */ + public WebSocketHttpHeaders(HttpHeaders headers) { + this(headers, false); + } + + /** + * Private constructor that can create read-only {@code WebSocketHttpHeader} instances. + */ + private WebSocketHttpHeaders(HttpHeaders headers, boolean readOnly) { + this.headers = readOnly ? HttpHeaders.readOnlyHttpHeaders(headers) : headers; + } + + /** + * Returns {@code WebSocketHttpHeaders} object that can only be read, not written to. + */ + public static WebSocketHttpHeaders readOnlyWebSocketHttpHeaders(WebSocketHttpHeaders headers) { + return new WebSocketHttpHeaders(headers, true); + } + + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Accept} header. + * @param secWebSocketAccept the value of the header + */ + public void setSecWebSocketAccept(String secWebSocketAccept) { + set(SEC_WEBSOCKET_ACCEPT, secWebSocketAccept); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Accept} header. + * @return the value of the header + */ + public String getSecWebSocketAccept() { + return getFirst(SEC_WEBSOCKET_ACCEPT); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Extensions} header. + * @return the value of the header + */ + public List getSecWebSocketExtensions() { + List values = get(SEC_WEBSOCKET_EXTENSIONS); + if (CollectionUtils.isEmpty(values)) { + return Collections.emptyList(); + } + else { + List result = new ArrayList(values.size()); + for (String value : values) { + result.addAll(WebSocketExtension.parseExtensions(value)); + } + return result; + } + } + + /** + * Sets the (new) value(s) of the {@code Sec-WebSocket-Extensions} header. + * @param extensions the values for the header + */ + public void setSecWebSocketExtensions(List extensions) { + List result = new ArrayList(extensions.size()); + for(WebSocketExtension extension : extensions) { + result.add(extension.toString()); + } + set(SEC_WEBSOCKET_EXTENSIONS, toCommaDelimitedString(result)); + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Key} header. + * @param secWebSocketKey the value of the header + */ + public void setSecWebSocketKey(String secWebSocketKey) { + set(SEC_WEBSOCKET_KEY, secWebSocketKey); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Key} header. + * @return the value of the header + */ + public String getSecWebSocketKey() { + return getFirst(SEC_WEBSOCKET_KEY); + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Protocol} header. + * @param secWebSocketProtocol the value of the header + */ + public void setSecWebSocketProtocol(String secWebSocketProtocol) { + if (secWebSocketProtocol != null) { + set(SEC_WEBSOCKET_PROTOCOL, secWebSocketProtocol); + } + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Protocol} header. + * @param secWebSocketProtocols the value of the header + */ + public void setSecWebSocketProtocol(List secWebSocketProtocols) { + set(SEC_WEBSOCKET_PROTOCOL, toCommaDelimitedString(secWebSocketProtocols)); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Key} header. + * @return the value of the header + */ + public List getSecWebSocketProtocol() { + List values = get(SEC_WEBSOCKET_PROTOCOL); + if (CollectionUtils.isEmpty(values)) { + return Collections.emptyList(); + } + else if (values.size() == 1) { + return getFirstValueAsList(SEC_WEBSOCKET_PROTOCOL); + } + else { + return values; + } + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Version} header. + * @param secWebSocketVersion the value of the header + */ + public void setSecWebSocketVersion(String secWebSocketVersion) { + set(SEC_WEBSOCKET_VERSION, secWebSocketVersion); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Version} header. + * @return the value of the header + */ + public String getSecWebSocketVersion() { + return getFirst(SEC_WEBSOCKET_VERSION); + } + + + // Single string methods + + /** + * Return the first header value for the given header name, if any. + * @param headerName the header name + * @return the first header value; or {@code null} + */ + @Override + public String getFirst(String headerName) { + return this.headers.getFirst(headerName); + } + + /** + * Add the given, single header value under the given name. + * @param headerName the header name + * @param headerValue the header value + * @throws UnsupportedOperationException if adding headers is not supported + * @see #put(String, List) + * @see #set(String, String) + */ + @Override + public void add(String headerName, String headerValue) { + this.headers.add(headerName, headerValue); + } + + /** + * Set the given, single header value under the given name. + * @param headerName the header name + * @param headerValue the header value + * @throws UnsupportedOperationException if adding headers is not supported + * @see #put(String, List) + * @see #add(String, String) + */ + @Override + public void set(String headerName, String headerValue) { + this.headers.set(headerName, headerValue); + } + + @Override + public void setAll(Map values) { + this.headers.setAll(values); + } + + @Override + public Map toSingleValueMap() { + return this.headers.toSingleValueMap(); + } + + // Map implementation + + @Override + public int size() { + return this.headers.size(); + } + + @Override + public boolean isEmpty() { + return this.headers.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return this.headers.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + return this.headers.containsValue(value); + } + + @Override + public List get(Object key) { + return this.headers.get(key); + } + + @Override + public List put(String key, List value) { + return this.headers.put(key, value); + } + + @Override + public List remove(Object key) { + return this.headers.remove(key); + } + + @Override + public void putAll(Map> m) { + this.headers.putAll(m); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.keySet(); + } + + @Override + public Collection> values() { + return this.headers.values(); + } + + @Override + public Set>> entrySet() { + return this.headers.entrySet(); + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof WebSocketHttpHeaders)) { + return false; + } + WebSocketHttpHeaders otherHeaders = (WebSocketHttpHeaders) other; + return this.headers.equals(otherHeaders.headers); + } + + @Override + public int hashCode() { + return this.headers.hashCode(); + } + + @Override + public String toString() { + return this.headers.toString(); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java index d0dd2f95014..0845e83018d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java @@ -117,7 +117,7 @@ public abstract class AbstractWebSocketIntegrationTests { static abstract class AbstractRequestUpgradeStrategyConfig { @Bean - public HandshakeHandler handshakeHandler() { + public DefaultHandshakeHandler handshakeHandler() { return new DefaultHandshakeHandler(requestUpgradeStrategy()); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java index 2a674af1736..f760ee110fb 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java @@ -21,41 +21,35 @@ import static org.junit.Assert.assertThat; import org.hamcrest.Matchers; import org.junit.Test; +import org.springframework.web.socket.support.WebSocketExtension; import java.util.ArrayList; import java.util.List; /** - * Test fixture for {@link WebSocketExtension} + * Test fixture for {@link org.springframework.web.socket.support.WebSocketExtension} * @author Brian Clozel */ public class WebSocketExtensionTests { @Test public void parseHeaderSingle() { - List extensions = WebSocketExtension.parseHeader("x-test-extension ; foo=bar"); + List extensions = WebSocketExtension.parseExtensions("x-test-extension ; foo=bar ; bar=baz"); assertThat(extensions, Matchers.hasSize(1)); WebSocketExtension extension = extensions.get(0); + assertEquals("x-test-extension", extension.getName()); - assertEquals(1, extension.getParameters().size()); + assertEquals(2, extension.getParameters().size()); assertEquals("bar", extension.getParameters().get("foo")); + assertEquals("baz", extension.getParameters().get("bar")); } @Test public void parseHeaderMultiple() { - List extensions = WebSocketExtension.parseHeader("x-foo-extension, x-bar-extension"); + List extensions = WebSocketExtension.parseExtensions("x-foo-extension, x-bar-extension"); assertThat(extensions, Matchers.hasSize(2)); assertEquals("x-foo-extension", extensions.get(0).getName()); assertEquals("x-bar-extension", extensions.get(1).getName()); } - @Test - public void parseHeaders() { - List extensions = new ArrayList(); - extensions.add("x-foo-extension, x-bar-extension"); - extensions.add("x-test-extension"); - List parsedExtensions = WebSocketExtension.parseHeaders(extensions); - assertThat(parsedExtensions, Matchers.hasSize(3)); - } - } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java new file mode 100644 index 00000000000..3a2c5057d65 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java @@ -0,0 +1,173 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; +import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; +import org.springframework.web.socket.client.endpoint.StandardWebSocketClient; +import org.springframework.web.socket.client.jetty.JettyWebSocketClient; +import org.springframework.web.socket.server.DefaultHandshakeHandler; +import org.springframework.web.socket.server.HandshakeFailureException; +import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.config.EnableWebSocket; +import org.springframework.web.socket.server.config.WebSocketConfigurer; +import org.springframework.web.socket.server.config.WebSocketHandlerRegistry; +import org.springframework.web.socket.support.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketHttpHeaders; + +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; + +import static org.junit.Assert.assertEquals; + +/** + * Client and server-side WebSocket integration tests. + * + * @author Rossen Stoyanchev + */ +@RunWith(Parameterized.class) +public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { + + @Parameterized.Parameters + public static Iterable arguments() { + return Arrays.asList(new Object[][]{ + {new JettyWebSocketTestServer(), new JettyWebSocketClient()}, + {new TomcatWebSocketTestServer(), new StandardWebSocketClient()} + }); + }; + + + @Override + protected Class[] getAnnotatedConfigClasses() { + return new Class[] { TestWebSocketConfigurer.class }; + } + + @Test + public void subProtocolNegotiation() throws Exception { + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); + headers.setSecWebSocketProtocol("foo"); + + WebSocketSession session = this.webSocketClient.doHandshake( + new WebSocketHandlerAdapter(), headers, new URI(getWsBaseUrl() + "/ws")).get(); + + assertEquals("foo", session.getAcceptedProtocol()); + } + + + @Configuration + @EnableWebSocket + static class TestWebSocketConfigurer implements WebSocketConfigurer { + + @Autowired + private DefaultHandshakeHandler handshakeHandler; // can't rely on classpath for server detection + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + this.handshakeHandler.setSupportedProtocols("foo", "bar", "baz"); + registry.addHandler(serverHandler(), "/ws").setHandshakeHandler(this.handshakeHandler); + } + + @Bean + public TestServerWebSocketHandler serverHandler() { + return new TestServerWebSocketHandler(); + } + } + + private static class TestClientWebSocketHandler extends TextWebSocketHandlerAdapter { + + private final TextMessage[] messagesToSend; + + private final int expected; + + private final List actual = new CopyOnWriteArrayList(); + + private final CountDownLatch latch; + + + public TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) { + this.messagesToSend = messagesToSend; + this.expected = expectedNumberOfMessages; + this.latch = new CountDownLatch(this.expected); + } + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + for (TextMessage message : this.messagesToSend) { + session.sendMessage(message); + } + } + + @Override + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + this.actual.add(message); + this.latch.countDown(); + } + } + + private static class TestServerWebSocketHandler extends TextWebSocketHandlerAdapter { + + } + + private static class TestRequestUpgradeStrategy implements RequestUpgradeStrategy { + + private final RequestUpgradeStrategy delegate; + + private List extensions= new ArrayList(); + + + private TestRequestUpgradeStrategy(RequestUpgradeStrategy delegate, String... supportedExtensions) { + this.delegate = delegate; + for (String name : supportedExtensions) { + this.extensions.add(new WebSocketExtension(name)); + } + } + + @Override + public String[] getSupportedVersions() { + return this.delegate.getSupportedVersions(); + } + + @Override + public List getSupportedExtensions(ServerHttpRequest request) { + return this.extensions; + } + + @Override + public void upgrade(ServerHttpRequest request, ServerHttpResponse response, + String selectedProtocol, List selectedExtensions, + WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { + + this.delegate.upgrade(request, response, selectedProtocol, selectedExtensions, wsHandler, attributes); + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java index 4bf61438eda..0a4c4ec7c4f 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java @@ -31,6 +31,7 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator; import org.springframework.web.socket.support.WebSocketHandlerDecorator; +import org.springframework.web.socket.support.WebSocketHttpHeaders; import org.springframework.web.util.UriComponentsBuilder; import static org.junit.Assert.*; @@ -54,7 +55,7 @@ public class WebSocketConnectionManagerTests { manager.setSubProtocols(subprotocols); manager.openConnection(); - HttpHeaders expectedHeaders = new HttpHeaders(); + WebSocketHttpHeaders expectedHeaders = new WebSocketHttpHeaders(); expectedHeaders.setSecWebSocketProtocol(subprotocols); assertEquals(expectedHeaders, client.headers); @@ -150,7 +151,7 @@ public class WebSocketConnectionManagerTests { @Override public ListenableFuture doHandshake(WebSocketHandler webSocketHandler, - HttpHeaders headers, URI uri) { + WebSocketHttpHeaders headers, URI uri) { this.webSocketHandler = webSocketHandler; this.headers = headers; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java index 8fe32355fce..5b2919ed568 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java @@ -34,6 +34,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; +import org.springframework.web.socket.support.WebSocketHttpHeaders; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -51,12 +52,12 @@ public class StandardWebSocketClientTests { private WebSocketHandler wsHandler; - private HttpHeaders headers; + private WebSocketHttpHeaders headers; @Before public void setup() { - this.headers = new HttpHeaders(); + this.headers = new WebSocketHttpHeaders(); this.wsHandler = new WebSocketHandlerAdapter(); this.wsContainer = mock(WebSocketContainer.class); this.wsClient = new StandardWebSocketClient(this.wsContainer); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java index c1960a52b3e..14d0b770eae 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java @@ -36,6 +36,7 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession; import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; +import org.springframework.web.socket.support.WebSocketHttpHeaders; import static org.junit.Assert.*; @@ -80,7 +81,7 @@ public class JettyWebSocketClientTests { @Test public void doHandshake() throws Exception { - HttpHeaders headers = new HttpHeaders(); + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); headers.setSecWebSocketProtocol(Arrays.asList("echo")); this.wsSession = this.client.doHandshake(new TextWebSocketHandlerAdapter(), headers, new URI(this.wsUrl)).get(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java index 537b8a524df..5a96e5ec393 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java @@ -24,8 +24,10 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.web.socket.AbstractHttpRequestTests; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; +import org.springframework.web.socket.support.WebSocketHttpHeaders; import static org.mockito.Mockito.*; @@ -57,19 +59,22 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] { "13" }); + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); + this.servletRequest.setMethod("GET"); - this.request.getHeaders().setUpgrade("WebSocket"); - this.request.getHeaders().setConnection("Upgrade"); - this.request.getHeaders().setSecWebSocketVersion("13"); - this.request.getHeaders().setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); - this.request.getHeaders().setSecWebSocketProtocol("STOMP"); + headers.setUpgrade("WebSocket"); + headers.setConnection("Upgrade"); + headers.setSecWebSocketVersion("13"); + headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); + headers.setSecWebSocketProtocol("STOMP"); WebSocketHandler handler = new TextWebSocketHandlerAdapter(); Map attributes = Collections.emptyMap(); this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); - verify(this.upgradeStrategy).upgrade(this.request, this.response, "STOMP", handler, attributes); + verify(this.upgradeStrategy).upgrade(this.request, this.response, + "STOMP", Collections.emptyList(), handler, attributes); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java index ab36aa549ec..e769d5b6520 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java @@ -41,7 +41,7 @@ import static org.junit.Assert.*; /** - * Test fixture for WebSocket Java config support. + * Integration tests for WebSocket Java server-side configuration. * * @author Rossen Stoyanchev */ diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java index 9f145cdd1e5..730c12cdcdc 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java @@ -26,7 +26,7 @@ import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; -import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; @@ -59,7 +59,7 @@ public class TestSockJsSession extends AbstractSockJsSession { private String subProtocol; - private List extensions = new ArrayList(); + private List extensions = new ArrayList<>(); public TestSockJsSession(String sessionId, SockJsServiceConfig config, @@ -83,61 +83,37 @@ public class TestSockJsSession extends AbstractSockJsSession { return this.headers; } - /** - * @return the headers - */ public HttpHeaders getHeaders() { return this.headers; } - /** - * @param headers the headers to set - */ public void setHeaders(HttpHeaders headers) { this.headers = headers; } - /** - * @return the principal - */ @Override public Principal getPrincipal() { return this.principal; } - /** - * @param principal the principal to set - */ public void setPrincipal(Principal principal) { this.principal = principal; } - /** - * @return the localAddress - */ @Override public InetSocketAddress getLocalAddress() { return this.localAddress; } - /** - * @param localAddress the remoteAddress to set - */ public void setLocalAddress(InetSocketAddress localAddress) { this.localAddress = localAddress; } - /** - * @return the remoteAddress - */ @Override public InetSocketAddress getRemoteAddress() { return this.remoteAddress; } - /** - * @param remoteAddress the remoteAddress to set - */ public void setRemoteAddress(InetSocketAddress remoteAddress) { this.remoteAddress = remoteAddress; } @@ -151,17 +127,14 @@ public class TestSockJsSession extends AbstractSockJsSession { this.subProtocol = protocol; } - /** - * @return the extensions - */ @Override - public List getExtensions() { return this.extensions; } + public List getExtensions() { + return this.extensions; + } - /** - * - * @param extensions the extensions to set - */ - public void setExtensions(List extensions) { this.extensions = extensions; } + public void setExtensions(List extensions) { + this.extensions = extensions; + } public CloseStatus getCloseStatus() { return this.closeStatus; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java index 17935ac62db..7c8775cfc73 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java @@ -27,7 +27,6 @@ import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; -import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; @@ -63,32 +62,20 @@ public class TestWebSocketSession implements WebSocketSession { private HttpHeaders headers; - /** - * @return the id - */ @Override public String getId() { return this.id; } - /** - * @param id the id to set - */ public void setId(String id) { this.id = id; } - /** - * @return the uri - */ @Override public URI getUri() { return this.uri; } - /** - * @param uri the uri to set - */ public void setUri(URI uri) { this.uri = uri; } @@ -99,117 +86,72 @@ public class TestWebSocketSession implements WebSocketSession { return this.headers; } - /** - * @return the headers - */ public HttpHeaders getHeaders() { return this.headers; } - /** - * @param headers the headers to set - */ public void setHeaders(HttpHeaders headers) { this.headers = headers; } - /** - * @param attributes the attributes to set - */ public void setHandshakeAttributes(Map attributes) { this.attributes = attributes; } - /** - * @return the attributes - */ @Override public Map getHandshakeAttributes() { return this.attributes; } - /** - * @return the principal - */ @Override public Principal getPrincipal() { return this.principal; } - /** - * @param principal the principal to set - */ public void setPrincipal(Principal principal) { this.principal = principal; } - /** - * @return the localAddress - */ @Override public InetSocketAddress getLocalAddress() { return this.localAddress; } - /** - * @param localAddress the remoteAddress to set - */ public void setLocalAddress(InetSocketAddress localAddress) { this.localAddress = localAddress; } - /** - * @return the remoteAddress - */ @Override public InetSocketAddress getRemoteAddress() { return this.remoteAddress; } - /** - * @param remoteAddress the remoteAddress to set - */ public void setRemoteAddress(InetSocketAddress remoteAddress) { this.remoteAddress = remoteAddress; } - /** - * @return the subProtocol - */ public String getAcceptedProtocol() { return this.protocol; } - /** - * @param protocol the subProtocol to set - */ public void setAcceptedProtocol(String protocol) { this.protocol = protocol; } - /** - * @return the extensions - */ @Override - public List getExtensions() { return this.extensions; } + public List getExtensions() { + return this.extensions; + } - /** - * - * @param extensions the extensions to set - */ - public void setExtensions(List extensions) { this.extensions = extensions; } + public void setExtensions(List extensions) { + this.extensions = extensions; + } - /** - * @return the open - */ @Override public boolean isOpen() { return this.open; } - /** - * @param open the open to set - */ public void setOpen(boolean open) { this.open = open; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/WebSocketHttpHeadersTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/WebSocketHttpHeadersTests.java new file mode 100644 index 00000000000..7cca7129775 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/WebSocketHttpHeadersTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.support; + +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertThat; + +/** + * Unit tests for WebSocketHttpHeaders. + * + * @author Rossen Stoyanchev + */ +public class WebSocketHttpHeadersTests { + + private WebSocketHttpHeaders headers; + + @Before + public void setUp() { + headers = new WebSocketHttpHeaders(); + } + + @Test + public void parseWebSocketExtensions() { + List extensions = new ArrayList(); + extensions.add("x-foo-extension, x-bar-extension"); + extensions.add("x-test-extension"); + this.headers.put(WebSocketHttpHeaders.SEC_WEBSOCKET_EXTENSIONS, extensions); + + List parsedExtensions = this.headers.getSecWebSocketExtensions(); + assertThat(parsedExtensions, Matchers.hasSize(3)); + } + +}