diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java index bf14730e9cb..b935e21889d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.Map; import org.springframework.http.HttpHeaders; @@ -48,6 +49,13 @@ public interface WebSocketSession { */ HttpHeaders getHandshakeHeaders(); + /** + * Handshake request specific attributes. + * To add attributes to a server-side WebSocket session see + * {@link org.springframework.web.socket.server.HandshakeInterceptor}. + */ + Map getHandshakeAttributes(); + /** * Return a {@link java.security.Principal} instance containing the name of the * authenticated user. If the user has not been authenticated, the method returns diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java index 2faffbfb13c..fc3fe828849 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java @@ -16,6 +16,7 @@ package org.springframework.web.socket.adapter; import java.io.IOException; +import java.util.Map; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -38,6 +39,24 @@ public abstract class AbstractWebSocketSesssion implements DelegatingWebSocke private T delegateSession; + private final Map handshakeAttributes; + + + /** + * Class constructor + * + * @param handshakeAttributes attributes from the HTTP handshake to make available + * through the WebSocket session + */ + public AbstractWebSocketSesssion(Map handshakeAttributes) { + this.handshakeAttributes = handshakeAttributes; + } + + + @Override + public Map getHandshakeAttributes() { + return this.handshakeAttributes; + } /** * @return the WebSocket session to delegate to diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java index 73260ca668e..807eed19088 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java @@ -20,8 +20,8 @@ import org.springframework.web.socket.WebSocketSession; /** - * A contract for {@link WebSocketSession} implementations that delegate to another - * WebSocket session (e.g. a native session). + * A contract for a {@link WebSocketSession} that delegates to another WebSocket session + * (e.g. a native session). * * @param T the type of the delegate WebSocket session * diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java index 708ac66bc1b..e67d94fc51b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.util.ObjectUtils; @@ -46,8 +47,11 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion handshakeAttributes) { + super(handshakeAttributes); this.principal = principal; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java index 8dd2c1fd38e..f59ac55b275 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.Map; import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; @@ -39,7 +40,7 @@ import org.springframework.web.socket.WebSocketSession; */ public class StandardWebSocketSession extends AbstractWebSocketSesssion { - private final HttpHeaders headers; + private final HttpHeaders handshakeHeaders; private final InetSocketAddress localAddress; @@ -50,12 +51,17 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion handshakeAttributes, + InetSocketAddress localAddress, InetSocketAddress remoteAddress) { + super(handshakeAttributes); handshakeHeaders = (handshakeHeaders != null) ? handshakeHeaders : new HttpHeaders(); - this.headers = HttpHeaders.readOnlyHttpHeaders(handshakeHeaders); + this.handshakeHeaders = HttpHeaders.readOnlyHttpHeaders(handshakeHeaders); this.localAddress = localAddress; this.remoteAddress = remoteAddress; } @@ -74,7 +80,7 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssionemptyMap()); } /** - * + * Perform the actual handshake to establish a connection to the server. * * @param webSocketHandler the client-side handler for WebSocket messages * @param headers HTTP headers to use for the handshake, with unwanted (forbidden) * headers filtered out, never {@code null} * @param uri the target URI for the handshake, never {@code null} * @param subProtocols requested sub-protocols, or an empty list + * @param handshakeAttributes attributes to make available via + * {@link WebSocketSession#getHandshakeAttributes()}; currently always an empty map. + * * @return the established WebSocket session + * * @throws WebSocketConnectFailureException */ protected abstract WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, - HttpHeaders headers, URI uri, List subProtocols) throws WebSocketConnectFailureException; + HttpHeaders headers, URI uri, List subProtocols, + Map handshakeAttributes) throws WebSocketConnectFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java index 931262d15f7..a6ff22586fe 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java @@ -63,14 +63,16 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { @Override - protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, - HttpHeaders headers, URI uri, List protocols) throws WebSocketConnectFailureException { + protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, HttpHeaders headers, + URI uri, List protocols, Map handshakeAttributes) + throws WebSocketConnectFailureException { int port = getPort(uri); InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port); InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port); - StandardWebSocketSession session = new StandardWebSocketSession(headers, localAddress, remoteAddress); + StandardWebSocketSession session = new StandardWebSocketSession(headers, + handshakeAttributes, localAddress, remoteAddress); ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create(); configBuidler.configurator(new StandardWebSocketClientConfigurator(headers)); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java index 302b66ab334..159d683e9ee 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java @@ -17,7 +17,9 @@ package org.springframework.web.socket.client.jetty; import java.net.URI; +import java.security.Principal; import java.util.List; +import java.util.Map; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.springframework.context.SmartLifecycle; @@ -122,25 +124,26 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma } @Override - public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVariables) + public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVars) throws WebSocketConnectFailureException { - UriComponents uriComponents = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVariables).encode(); + UriComponents uriComponents = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode(); return doHandshake(webSocketHandler, null, uriComponents.toUri()); } @Override public WebSocketSession doHandshakeInternal(WebSocketHandler wsHandler, HttpHeaders headers, - URI uri, List protocols) throws WebSocketConnectFailureException { + URI uri, List protocols, Map handshakeAttributes) + throws WebSocketConnectFailureException { ClientUpgradeRequest request = new ClientUpgradeRequest(); request.setSubProtocols(protocols); - for (String header : headers.keySet()) { request.setHeader(header, headers.get(header)); } - JettyWebSocketSession wsSession = new JettyWebSocketSession(null); + Principal user = getUser(); + JettyWebSocketSession wsSession = new JettyWebSocketSession(user, handshakeAttributes); JettyWebSocketHandlerAdapter listener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession); try { @@ -153,4 +156,13 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma } } + + /** + * @return the user to make available through {@link WebSocketSession#getPrincipal()}; + * by default this method returns {@code null} + */ + protected Principal getUser() { + return null; + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java index 87587e9dbd6..660dac28ea2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java @@ -23,6 +23,7 @@ import java.security.NoSuchAlgorithmException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import javax.xml.bind.DatatypeConverter; @@ -98,7 +99,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { @Override public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler) throws IOException, HandshakeFailureException { + WebSocketHandler webSocketHandler, Map attributes) throws IOException, HandshakeFailureException { logger.debug("Starting handshake for " + request.getURI()); @@ -150,7 +151,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { logger.trace("Upgrading with " + webSocketHandler); } - this.requestUpgradeStrategy.upgrade(request, response, selectedProtocol, webSocketHandler); + this.requestUpgradeStrategy.upgrade(request, response, selectedProtocol, webSocketHandler, attributes); return true; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java index 75654847ce9..a463961fc3b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.server; import java.io.IOException; +import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -29,7 +30,9 @@ import org.springframework.web.socket.support.PerConnectionWebSocketHandler; * @author Rossen Stoyanchev * @since 4.0 * + * @see HandshakeInterceptor * @see org.springframework.web.socket.server.support.WebSocketHttpRequestHandler + * @see org.springframework.web.socket.sockjs.SockJsService */ public interface HandshakeHandler { @@ -38,9 +41,11 @@ public interface HandshakeHandler { * * @param request the current request * @param response the current response - * @param webSocketHandler the handler to process WebSocket messages; see + * @param wsHandler the handler to process WebSocket messages; see * {@link PerConnectionWebSocketHandler} for providing a handler with * per-connection lifecycle. + * @param attributes handshake request specific attributes to be set on the WebSocket + * session and thus made available to the {@link WebSocketHandler} * * @return whether the handshake negotiation was successful or not. In either case the * response status, headers, and body will have been updated to reflect the @@ -53,7 +58,7 @@ public interface HandshakeHandler { * opposed to a failure to successfully negotiate the requirements of the * handshake request. */ - boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler webSocketHandler) - throws IOException, HandshakeFailureException; + boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, + Map attributes) throws IOException, HandshakeFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeInterceptor.java new file mode 100644 index 00000000000..86f7e0f3fbf --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeInterceptor.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server; + +import java.util.Map; + +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + + +/** + * Interceptor for WebSocket handshake requests. Can be used to inspect the handshake + * request and response as well as to pass attributes to the target + * {@link WebSocketHandler}. + * + * @author Rossen Stoyanchev + * @since 4.0 + * + * @see org.springframework.web.socket.server.support.WebSocketHttpRequestHandler + * @see org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService + */ +public interface HandshakeInterceptor { + + /** + * Invoked before the handshake is processed. + * + * @param request the current request + * @param response the current response + * @param wsHandler the target WebSocket handler + * @param attributes attributes to make available via + * {@link WebSocketSession#getHandshakeAttributes()} + * + * @return whether to proceed with the handshake {@code true} or abort {@code false} + */ + boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) throws Exception; + + /** + * Invoked after the handshake is done. The response status and headers indicate the + * results of the handshake, i.e. whether it was successful or not. + * + * @param request the current request + * @param response the current response + * @param wsHandler the target WebSocket handler + * @param exception an exception raised during the handshake, or {@code null} + */ + void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Exception exception); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java index 1421f519844..3dc5d83aed1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.server; import java.io.IOException; +import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -40,14 +41,18 @@ public interface RequestUpgradeStrategy { * Perform runtime specific steps to complete the upgrade. Invoked after successful * negotiation of the handshake request. * - * @param webSocketHandler the handler for WebSocket messages + * @param request the current request + * @param response the current response + * @param acceptedProtocol the accepted sub-protocol, if any + * @param wsHandler the handler for WebSocket messages + * @param attributes handshake context attributes * * @throws HandshakeFailureException thrown when handshake processing failed to * complete due to an internal, unrecoverable error, i.e. a server error as * opposed to a failure to successfully negotiate the requirements of the * handshake request. */ - void upgrade(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, - WebSocketHandler webSocketHandler) throws IOException, HandshakeFailureException; + void upgrade(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, + WebSocketHandler wsHandler, Map attributes) throws IOException, HandshakeFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java index 0a02e477618..127e6511690 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java @@ -18,6 +18,7 @@ package org.springframework.web.socket.server.support; import java.io.IOException; import java.net.InetSocketAddress; +import java.util.Map; import javax.websocket.Endpoint; @@ -44,14 +45,15 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS @Override - public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String acceptedProtocol, WebSocketHandler wsHandler) throws IOException, HandshakeFailureException { + public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, + WebSocketHandler wsHandler, Map attributes) + throws IOException, HandshakeFailureException { HttpHeaders headers = request.getHeaders(); - InetSocketAddress localAddress = request.getLocalAddress(); - InetSocketAddress remoteAddress = request.getRemoteAddress(); + InetSocketAddress localAddr = request.getLocalAddress(); + InetSocketAddress remoteAddr = request.getRemoteAddress(); - StandardWebSocketSession wsSession = new StandardWebSocketSession(headers, localAddress, remoteAddress); + StandardWebSocketSession wsSession = new StandardWebSocketSession(headers, attributes, localAddr, remoteAddr); StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, wsSession); upgradeInternal(request, response, acceptedProtocol, endpoint); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HandshakeInterceptorChain.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HandshakeInterceptorChain.java new file mode 100644 index 00000000000..76dc2a82f53 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HandshakeInterceptorChain.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; + + +/** + * A helper class that assists with invoking a list of handshake interceptors. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class HandshakeInterceptorChain { + + private static final Log logger = LogFactory.getLog(WebSocketHttpRequestHandler.class); + + private final List interceptors; + + private final WebSocketHandler wsHandler; + + private int interceptorIndex = -1; + + + public HandshakeInterceptorChain(List interceptors, WebSocketHandler wsHandler) { + this.interceptors = (interceptors != null) ? interceptors : Collections.emptyList(); + this.wsHandler = wsHandler; + } + + + public boolean applyBeforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + Map attributes) throws Exception { + + for (int i = 0; i < this.interceptors.size(); i++) { + HandshakeInterceptor interceptor = this.interceptors.get(i); + if (!interceptor.beforeHandshake(request, response, this.wsHandler, attributes)) { + applyAfterHandshake(request, response, null); + return false; + } + this.interceptorIndex = i; + } + return true; + } + + + public void applyAfterHandshake(ServerHttpRequest request, ServerHttpResponse response, Exception failure) { + for (int i = this.interceptorIndex; i >= 0; i--) { + HandshakeInterceptor interceptor = this.interceptors.get(i); + try { + interceptor.afterHandshake(request, response, this.wsHandler, failure); + } + catch (Throwable t) { + logger.warn("HandshakeInterceptor afterHandshake threw exception " + t); + } + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java new file mode 100644 index 00000000000..a23a7afbe00 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.util.Collection; +import java.util.Enumeration; +import java.util.Map; + +import javax.servlet.http.HttpSession; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.util.CollectionUtils; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.server.HandshakeInterceptor; + + +/** + * An interceptor to copy HTTP session attributes into the map of "handshake attributes" + * made available through {@link WebSocketSession#getHandshakeAttributes()}. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor { + + private static Log logger = LogFactory.getLog(HttpSessionHandshakeInterceptor.class); + + private Collection attributeNames; + + + /** + * A constructor for copying all available HTTP session attributes. + */ + public HttpSessionHandshakeInterceptor() { + this(null); + } + + /** + * A constructor for copying a subset of HTTP session attributes. + * @param attributeNames the HTTP session attributes to copy + */ + public HttpSessionHandshakeInterceptor(Collection attributeNames) { + this.attributeNames = attributeNames; + } + + + @Override + public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) throws Exception { + + if (request instanceof ServletServerHttpRequest) { + ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request; + HttpSession session = servletRequest.getServletRequest().getSession(false); + if (session != null) { + Enumeration names = session.getAttributeNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + if (CollectionUtils.isEmpty(this.attributeNames) || this.attributeNames.contains(name)) { + if (logger.isTraceEnabled()) { + logger.trace("Adding HTTP session attribute to handshake attributes: " + name); + } + attributes.put(name, session.getAttribute(name)); + } + else { + if (logger.isTraceEnabled()) { + logger.trace("Skipped HTTP session attribute"); + } + } + } + } + } + return true; + } + + @Override + public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Exception ex) { + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java index e6b30aa0d3c..184033ca0a3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.Map; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -85,7 +86,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String protocol, WebSocketHandler wsHandler) throws IOException { + String protocol, WebSocketHandler wsHandler, Map attrs) throws IOException { Assert.isInstanceOf(ServletServerHttpRequest.class, request); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -98,7 +99,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { throw new HandshakeFailureException("Not a WebSocket request"); } - JettyWebSocketSession wsSession = new JettyWebSocketSession(request.getPrincipal()); + JettyWebSocketSession wsSession = new JettyWebSocketSession(request.getPrincipal(), attrs); JettyWebSocketHandlerAdapter wsListener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession); servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, wsListener); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java index c3caa213d84..95df66ad68a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java @@ -17,6 +17,10 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -30,7 +34,9 @@ import org.springframework.util.Assert; import org.springframework.web.HttpRequestHandler; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.DefaultHandshakeHandler; +import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator; import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator; @@ -56,6 +62,8 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler { private final WebSocketHandler webSocketHandler; + private final List interceptors = new ArrayList(); + public WebSocketHttpRequestHandler(WebSocketHandler webSocketHandler) { this(webSocketHandler, new DefaultHandshakeHandler()); @@ -69,6 +77,23 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler { } + /** + * Configure one or more WebSocket handshake request interceptors. + */ + public void setHandshakeInterceptors(List interceptors) { + this.interceptors.clear(); + if (interceptors != null) { + this.interceptors.addAll(interceptors); + } + } + + /** + * Return the configured WebSocket handshake request interceptors. + */ + public List getHandshakeInterceptors() { + return this.interceptors; + } + /** * Decorate the WebSocketHandler provided to the class constructor. * @@ -81,14 +106,36 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler { } @Override - public void handleRequest(HttpServletRequest request, HttpServletResponse response) + public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse servletResponse) throws ServletException, IOException { - ServerHttpRequest httpRequest = new ServletServerHttpRequest(request); - ServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - - this.handshakeHandler.doHandshake(httpRequest, httpResponse, this.webSocketHandler); - httpResponse.flush(); + ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); + ServerHttpResponse response = new ServletServerHttpResponse(servletResponse); + + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, this.webSocketHandler); + HandshakeFailureException failure = null; + + try { + Map attributes = new HashMap(); + if (!chain.applyBeforeHandshake(request, response, attributes)) { + return; + } + this.handshakeHandler.doHandshake(request, response, this.webSocketHandler, attributes); + chain.applyAfterHandshake(request, response, null); + } + catch (HandshakeFailureException ex) { + failure = ex; + } + catch (Throwable t) { + failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), t); + } + finally { + if (failure != null) { + chain.applyAfterHandshake(request, response, failure); + throw failure; + } + response.flush(); + } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index 44597106f93..7a1fbc15027 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -269,8 +269,8 @@ public abstract class AbstractSockJsService implements SockJsService { * and raw WebSocket requests are delegated to abstract methods. */ @Override - public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler handler) - throws SockJsException { + public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler) throws SockJsException { String sockJsPath = getSockJsPath(request); if (sockJsPath == null) { @@ -301,7 +301,7 @@ public abstract class AbstractSockJsService implements SockJsService { this.iframeHandler.handle(request, response); } else if (sockJsPath.equals("/websocket")) { - handleRawWebSocketRequest(request, response, handler); + handleRawWebSocketRequest(request, response, wsHandler); } else { String[] pathSegments = StringUtils.tokenizeToStringArray(sockJsPath.substring(1), "/"); @@ -318,7 +318,7 @@ public abstract class AbstractSockJsService implements SockJsService { response.setStatusCode(HttpStatus.NOT_FOUND); return; } - handleTransportRequest(request, response, handler, sessionId, transport); + handleTransportRequest(request, response, wsHandler, sessionId, transport); } response.flush(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java index a64ff7c6a69..6d6033c9604 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -41,7 +42,10 @@ import org.springframework.util.ObjectUtils; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.server.DefaultHandshakeHandler; +import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.socket.server.support.HandshakeInterceptorChain; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.support.AbstractSockJsService; @@ -75,6 +79,8 @@ public class DefaultSockJsService extends AbstractSockJsService { private SockJsMessageCodec messageCodec; + private final List interceptors = new ArrayList(); + private final Map sessions = new ConcurrentHashMap(); private ScheduledFuture sessionCleanupTask; @@ -167,6 +173,23 @@ public class DefaultSockJsService extends AbstractSockJsService { } + /** + * Configure one or more WebSocket handshake request interceptors. + */ + public void setHandshakeInterceptors(List interceptors) { + this.interceptors.clear(); + if (interceptors != null) { + this.interceptors.addAll(interceptors); + } + } + + /** + * Return the configured WebSocket handshake request interceptors. + */ + public List getHandshakeInterceptors() { + return this.interceptors; + } + /** * The codec to use for encoding and decoding SockJS messages. * @exception IllegalStateException if no {@link SockJsMessageCodec} is available @@ -185,19 +208,42 @@ public class DefaultSockJsService extends AbstractSockJsService { @Override protected void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler) throws IOException { + WebSocketHandler wsHandler) throws IOException { - if (isWebSocketEnabled()) { - TransportHandler transportHandler = this.transportHandlers.get(TransportType.WEBSOCKET); - if (transportHandler != null) { - if (transportHandler instanceof HandshakeHandler) { - ((HandshakeHandler) transportHandler).doHandshake(request, response, webSocketHandler); - return; - } - } + if (!isWebSocketEnabled()) { + return; + } + + TransportHandler transportHandler = this.transportHandlers.get(TransportType.WEBSOCKET); + if ((transportHandler == null) || !(transportHandler instanceof HandshakeHandler)) { logger.warn("No handler for raw WebSocket messages"); + response.setStatusCode(HttpStatus.NOT_FOUND); + return; + } + + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, wsHandler); + HandshakeFailureException failure = null; + + try { + Map attributes = new HashMap(); + if (!chain.applyBeforeHandshake(request, response, attributes)) { + return; + } + ((HandshakeHandler) transportHandler).doHandshake(request, response, wsHandler, attributes); + chain.applyAfterHandshake(request, response, null); + } + catch (HandshakeFailureException ex) { + failure = ex; + } + catch (Throwable t) { + failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), t); + } + finally { + if (failure != null) { + chain.applyAfterHandshake(request, response, failure); + throw failure; + } } - response.setStatusCode(HttpStatus.NOT_FOUND); } @Override @@ -235,38 +281,61 @@ public class DefaultSockJsService extends AbstractSockJsService { return; } - WebSocketSession session = this.sessions.get(sessionId); - if (session == null) { - if (transportHandler instanceof SockJsSessionFactory) { - SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler; - session = createSockJsSession(sessionId, sessionFactory, wsHandler, request, response); + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, wsHandler); + SockJsException failure = null; + + try { + WebSocketSession session = this.sessions.get(sessionId); + if (session == null) { + if (transportHandler instanceof SockJsSessionFactory) { + Map attributes = new HashMap(); + if (!chain.applyBeforeHandshake(request, response, attributes)) { + return; + } + SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler; + session = createSockJsSession(sessionId, sessionFactory, wsHandler, attributes, request, response); + } + else { + response.setStatusCode(HttpStatus.NOT_FOUND); + logger.warn("Session not found"); + return; + } } - } - if (session == null) { - response.setStatusCode(HttpStatus.NOT_FOUND); - logger.warn("Session not found"); - return; - } - if (transportType.sendsNoCacheInstruction()) { - addNoCacheHeaders(response); - } + if (transportType.sendsNoCacheInstruction()) { + addNoCacheHeaders(response); + } - if (transportType.sendsSessionCookie() && isDummySessionCookieEnabled()) { - Cookie cookie = request.getCookies().get("JSESSIONID"); - String value = (cookie != null) ? cookie.getValue() : "dummy"; - response.getHeaders().set("Set-Cookie", "JSESSIONID=" + value + ";path=/"); - } + if (transportType.sendsSessionCookie() && isDummySessionCookieEnabled()) { + Cookie cookie = request.getCookies().get("JSESSIONID"); + String value = (cookie != null) ? cookie.getValue() : "dummy"; + response.getHeaders().set("Set-Cookie", "JSESSIONID=" + value + ";path=/"); + } - if (transportType.supportsCors()) { - addCorsHeaders(request, response); - } + if (transportType.supportsCors()) { + addCorsHeaders(request, response); + } - transportHandler.handleRequest(request, response, wsHandler, session); + transportHandler.handleRequest(request, response, wsHandler, session); + chain.applyAfterHandshake(request, response, null); + } + catch (SockJsException ex) { + failure = ex; + } + catch (Throwable t) { + failure = new SockJsException("Uncaught failure for request " + request.getURI(), sessionId, t); + } + finally { + if (failure != null) { + chain.applyAfterHandshake(request, response, failure); + throw failure; + } + } } private WebSocketSession createSockJsSession(String sessionId, SockJsSessionFactory sessionFactory, - WebSocketHandler handler, ServerHttpRequest request, ServerHttpResponse response) { + WebSocketHandler wsHandler, Map handshakeAttributes, + ServerHttpRequest request, ServerHttpResponse response) { synchronized (this.sessions) { AbstractSockJsSession session = this.sessions.get(sessionId); @@ -276,8 +345,9 @@ public class DefaultSockJsService extends AbstractSockJsService { if (this.sessionCleanupTask == null) { scheduleSessionTask(); } + logger.debug("Creating new session with session id \"" + sessionId + "\""); - session = sessionFactory.createSession(sessionId, handler); + session = sessionFactory.createSession(sessionId, wsHandler, handshakeAttributes); this.sessions.put(sessionId, session); return session; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java index 082a8e181a2..f13b3f8ba40 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java @@ -18,6 +18,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.io.IOException; import java.nio.charset.Charset; +import java.util.Map; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; @@ -49,8 +50,10 @@ public class EventSourceTransportHandler extends AbstractHttpSendingTransportHan } @Override - public StreamingSockJsSession createSession(String sessionId, WebSocketHandler handler) { - return new EventSourceStreamingSockJsSession(sessionId, getSockJsServiceConfig(), handler); + public StreamingSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new EventSourceStreamingSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } @Override @@ -61,8 +64,10 @@ public class EventSourceTransportHandler extends AbstractHttpSendingTransportHan private final class EventSourceStreamingSockJsSession extends StreamingSockJsSession { - private EventSourceStreamingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + private EventSourceStreamingSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java index e09df759947..980afa4fb19 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java @@ -18,6 +18,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.io.IOException; import java.nio.charset.Charset; +import java.util.Map; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -87,8 +88,10 @@ public class HtmlFileTransportHandler extends AbstractHttpSendingTransportHandle } @Override - public StreamingSockJsSession createSession(String sessionId, WebSocketHandler handler) { - return new HtmlFileStreamingSockJsSession(sessionId, getSockJsServiceConfig(), handler); + public StreamingSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new HtmlFileStreamingSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } @Override @@ -124,8 +127,10 @@ public class HtmlFileTransportHandler extends AbstractHttpSendingTransportHandle private final class HtmlFileStreamingSockJsSession extends StreamingSockJsSession { - private HtmlFileStreamingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + private HtmlFileStreamingSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java index 0d8cea2ddea..5c501646ef5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.nio.charset.Charset; +import java.util.Map; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -53,8 +54,10 @@ public class JsonpPollingTransportHandler extends AbstractHttpSendingTransportHa } @Override - public PollingSockJsSession createSession(String sessionId, WebSocketHandler handler) { - return new PollingSockJsSession(sessionId, getSockJsServiceConfig(), handler); + public PollingSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new PollingSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsSessionFactory.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsSessionFactory.java index 06d54ae09e5..ae15c847881 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsSessionFactory.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsSessionFactory.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.sockjs.transport.handler; +import java.util.Map; + import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.session.AbstractSockJsSession; @@ -31,10 +33,13 @@ public interface SockJsSessionFactory { /** * Create a new SockJS session. + * * @param sessionId the ID of the session - * @param webSocketHandler the underlying {@link WebSocketHandler} - * @return a new non-null session + * @param wsHandler the underlying {@link WebSocketHandler} + * @param attributes handshake request specific attributes + * + * @return a new session, never {@code null} */ - AbstractSockJsSession createSession(String sessionId, WebSocketHandler webSocketHandler); + AbstractSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, Map attributes); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java index f6d40b33b31..788cb93bd6f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.io.IOException; +import java.util.Collections; +import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -60,8 +62,10 @@ public class WebSocketTransportHandler extends TransportHandlerSupport } @Override - public AbstractSockJsSession createSession(String sessionId, WebSocketHandler webSocketHandler) { - return new WebSocketServerSockJsSession(sessionId, getSockJsServiceConfig(), webSocketHandler); + public AbstractSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new WebSocketServerSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } @Override @@ -71,7 +75,7 @@ public class WebSocketTransportHandler extends TransportHandlerSupport WebSocketServerSockJsSession sockJsSession = (WebSocketServerSockJsSession) wsSession; try { wsHandler = new SockJsWebSocketHandler(getSockJsServiceConfig(), wsHandler, sockJsSession); - this.handshakeHandler.doHandshake(request, response, wsHandler); + this.handshakeHandler.doHandshake(request, response, wsHandler, Collections.emptyMap()); } catch (Throwable t) { sockJsSession.tryCloseWithSockJsTransportError(t, CloseStatus.SERVER_ERROR); @@ -82,10 +86,10 @@ public class WebSocketTransportHandler extends TransportHandlerSupport // HandshakeHandler methods @Override - public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler handler) - throws IOException { + public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler handler, Map attributes) throws IOException { - return this.handshakeHandler.doHandshake(request, response, handler); + return this.handshakeHandler.doHandshake(request, response, handler, attributes); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java index ed59a339e2d..832525dfa20 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.nio.charset.Charset; +import java.util.Map; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; @@ -51,8 +52,10 @@ public class XhrPollingTransportHandler extends AbstractHttpSendingTransportHand } @Override - public PollingSockJsSession createSession(String sessionId, WebSocketHandler handler) { - return new PollingSockJsSession(sessionId, getSockJsServiceConfig(), handler); + public PollingSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new PollingSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java index 307471c7059..03da01986fc 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java @@ -18,6 +18,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.io.IOException; import java.nio.charset.Charset; +import java.util.Map; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; @@ -49,8 +50,10 @@ public class XhrStreamingTransportHandler extends AbstractHttpSendingTransportHa } @Override - public StreamingSockJsSession createSession(String sessionId, WebSocketHandler handler) { - return new XhrStreamingSockJsSession(sessionId, getSockJsServiceConfig(), handler); + public StreamingSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new XhrStreamingSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } @Override @@ -61,8 +64,10 @@ public class XhrStreamingTransportHandler extends AbstractHttpSendingTransportHa private final class XhrStreamingSockJsSession extends StreamingSockJsSession { - private XhrStreamingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + private XhrStreamingSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index 8f85b4ca2ed..f3a0c9d3472 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -18,7 +18,9 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.URI; import java.security.Principal; +import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -52,7 +54,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private ServerHttpAsyncRequestControl asyncRequestControl; - private String protocol; + private URI uri; private HttpHeaders handshakeHeaders; @@ -62,12 +64,21 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private InetSocketAddress remoteAddress; + private String acceptedProtocol; - public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler) { - super(id, config, wsHandler); + + public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map handshakeAttributes) { + + super(id, config, wsHandler, handshakeAttributes); } + @Override + public URI getUri() { + return this.uri; + } + @Override public HttpHeaders getHandshakeHeaders() { return this.handshakeHeaders; @@ -112,14 +123,14 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { * @param protocol the sub-protocol to set */ public void setAcceptedProtocol(String protocol) { - this.protocol = protocol; + this.acceptedProtocol = protocol; } /** * Return the selected sub-protocol to use. */ public String getAcceptedProtocol() { - return this.protocol; + return this.acceptedProtocol; } public synchronized void setInitialRequest(ServerHttpRequest request, ServerHttpResponse response, @@ -135,6 +146,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { throw new SockJsTransportFailureException("Failed to send \"open\" frame", getId(), t); } + this.uri = request.getURI(); this.handshakeHeaders = request.getHeaders(); this.principal = request.getPrincipal(); this.localAddress = request.getLocalAddress(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java index 0a94cce68ca..1fbd5e1de3a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java @@ -19,12 +19,11 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.EOFException; import java.io.IOException; import java.net.SocketException; -import java.net.URI; -import java.security.Principal; import java.util.ArrayList; import java.util.Arrays; import java.util.Date; import java.util.List; +import java.util.Map; import java.util.concurrent.ScheduledFuture; import org.apache.commons.logging.Log; @@ -51,18 +50,12 @@ public abstract class AbstractSockJsSession implements WebSocketSession { private final String id; - private URI uri; - - private String remoteHostName; - - private String remoteAddress; - - private Principal principal; - - private final SockJsServiceConfig sockJsServiceConfig; + private final SockJsServiceConfig config; private final WebSocketHandler handler; + private final Map handshakeAttributes; + private State state = State.NEW; private final long timeCreated = System.currentTimeMillis(); @@ -73,17 +66,21 @@ public abstract class AbstractSockJsSession implements WebSocketSession { /** - * @param sessionId the session ID + * @param id the session ID * @param config SockJS service configuration options * @param wsHandler the recipient of SockJS messages */ - public AbstractSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler wsHandler) { - Assert.notNull(sessionId, "sessionId is required"); + public AbstractSockJsSession(String id, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map handshakeAttributes) { + + Assert.notNull(id, "sessionId is required"); Assert.notNull(config, "sockJsConfig is required"); Assert.notNull(wsHandler, "webSocketHandler is required"); - this.id = sessionId; - this.sockJsServiceConfig = config; + + this.id = id; + this.config = config; this.handler = wsHandler; + this.handshakeAttributes = handshakeAttributes; } @Override @@ -91,13 +88,13 @@ public abstract class AbstractSockJsSession implements WebSocketSession { return this.id; } - @Override - public URI getUri() { - return this.uri; + public SockJsServiceConfig getSockJsServiceConfig() { + return this.config; } - public SockJsServiceConfig getSockJsServiceConfig() { - return this.sockJsServiceConfig; + @Override + public Map getHandshakeAttributes() { + return this.handshakeAttributes; } public boolean isNew() { @@ -306,13 +303,13 @@ public abstract class AbstractSockJsSession implements WebSocketSession { } protected void scheduleHeartbeat() { - Assert.state(this.sockJsServiceConfig.getTaskScheduler() != null, "heartbeatScheduler not configured"); + Assert.state(this.config.getTaskScheduler() != null, "heartbeatScheduler not configured"); cancelHeartbeat(); if (!isActive()) { return; } - Date time = new Date(System.currentTimeMillis() + this.sockJsServiceConfig.getHeartbeatTime()); - this.heartbeatTask = this.sockJsServiceConfig.getTaskScheduler().schedule(new Runnable() { + Date time = new Date(System.currentTimeMillis() + this.config.getHeartbeatTime()); + this.heartbeatTask = this.config.getTaskScheduler().schedule(new Runnable() { public void run() { try { sendHeartbeat(); @@ -323,7 +320,7 @@ public abstract class AbstractSockJsSession implements WebSocketSession { } }, time); if (logger.isTraceEnabled()) { - logger.trace("Scheduled heartbeat after " + this.sockJsServiceConfig.getHeartbeatTime()/1000 + " seconds"); + logger.trace("Scheduled heartbeat after " + this.config.getHeartbeatTime()/1000 + " seconds"); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java index 9fe6a5b8c72..1f1e4668f55 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.sockjs.transport.session; +import java.util.Map; + import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; @@ -28,8 +30,11 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsMessageCodec; */ public class PollingSockJsSession extends AbstractHttpSockJsSession { - public PollingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + + public PollingSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java index 7ae5770e440..017bc67f2ce 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; +import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -37,8 +38,10 @@ public class StreamingSockJsSession extends AbstractHttpSockJsSession { private int byteCount; - public StreamingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + public StreamingSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java index 3f93b23cbcd..ac3cbfcf6a4 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java @@ -18,7 +18,9 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.URI; import java.security.Principal; +import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.util.Assert; @@ -44,8 +46,17 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession private WebSocketSession wsSession; - public WebSocketServerSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler) { - super(id, config, wsHandler); + public WebSocketServerSockJsSession(String id, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(id, config, wsHandler, attributes); + } + + + @Override + public URI getUri() { + checkDelegateSessionInitialized(); + return this.wsSession.getUri(); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java index a964543b15d..cd797a92aa1 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java @@ -44,7 +44,7 @@ public class JettyWebSocketHandlerAdapterTests { public void setup() { this.session = mock(Session.class); this.webSocketHandler = mock(WebSocketHandler.class); - this.webSocketSession = new JettyWebSocketSession(null); + this.webSocketSession = new JettyWebSocketSession(null, null); this.adapter = new JettyWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java index ede2a5e9514..6a3b071adf7 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java @@ -50,7 +50,7 @@ public class StandardWebSocketHandlerAdapterTests { public void setup() { this.session = mock(Session.class); this.webSocketHandler = mock(WebSocketHandler.class); - this.webSocketSession = new StandardWebSocketSession(null, null, null); + this.webSocketSession = new StandardWebSocketSession(null, null, null, null); this.adapter = new StandardWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java index 2eaa55054ea..a761fd9c54e 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java @@ -113,7 +113,7 @@ public class JettyWebSocketClientTests { resp.setAcceptedSubProtocol(req.getSubProtocols().get(0)); } - JettyWebSocketSession session = new JettyWebSocketSession(null); + JettyWebSocketSession session = new JettyWebSocketSession(null, null); return new JettyWebSocketHandlerAdapter(webSocketHandler, session); } }); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java index c47ed6906da..537b8a524df 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java @@ -16,6 +16,9 @@ package org.springframework.web.socket.server; +import java.util.Collections; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -62,10 +65,11 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { this.request.getHeaders().setSecWebSocketProtocol("STOMP"); WebSocketHandler handler = new TextWebSocketHandlerAdapter(); + Map attributes = Collections.emptyMap(); - this.handshakeHandler.doHandshake(this.request, this.response, handler); + this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); - verify(this.upgradeStrategy).upgrade(request, response, "STOMP", handler); + verify(this.upgradeStrategy).upgrade(this.request, this.response, "STOMP", handler, attributes); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HandshakeInterceptorChainTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HandshakeInterceptorChainTests.java new file mode 100644 index 00000000000..e0f28e5f922 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HandshakeInterceptorChainTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.web.socket.AbstractHttpRequestTests; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; + +import static org.mockito.Mockito.*; + + +/** + * Test fixture for {@link HandshakeInterceptorChain}. + * + * @author Rossen Stoyanchev + */ +public class HandshakeInterceptorChainTests extends AbstractHttpRequestTests { + + private HandshakeInterceptor i1; + + private HandshakeInterceptor i2; + + private HandshakeInterceptor i3; + + private List interceptors; + + private WebSocketHandler wsHandler; + + private Map attributes; + + + @Before + public void setup() { + i1 = mock(HandshakeInterceptor.class); + i2 = mock(HandshakeInterceptor.class); + i3 = mock(HandshakeInterceptor.class); + interceptors = Arrays.asList(i1, i2, i3); + wsHandler = mock(WebSocketHandler.class); + attributes = new HashMap(); + } + + + @Test + public void success() throws Exception { + when(i1.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true); + when(i2.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true); + when(i3.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true); + + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(interceptors, wsHandler); + chain.applyBeforeHandshake(request, response, attributes); + + verify(i1).beforeHandshake(request, response, wsHandler, attributes); + verify(i2).beforeHandshake(request, response, wsHandler, attributes); + verify(i3).beforeHandshake(request, response, wsHandler, attributes); + verifyNoMoreInteractions(i1, i2, i3); + } + + @Test + public void applyBeforeHandshakeWithFalseReturnValue() throws Exception { + when(i1.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true); + when(i2.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(false); + + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(interceptors, wsHandler); + chain.applyBeforeHandshake(request, response, attributes); + + verify(i1).beforeHandshake(request, response, wsHandler, attributes); + verify(i1).afterHandshake(request, response, wsHandler, null); + verify(i2).beforeHandshake(request, response, wsHandler, attributes); + verifyNoMoreInteractions(i1, i2, i3); + } + + @Test + public void applyAfterHandshakeOnly() { + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(interceptors, wsHandler); + chain.applyAfterHandshake(request, response, null); + + verifyNoMoreInteractions(i1, i2, i3); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java new file mode 100644 index 00000000000..bc51fd5d77f --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.junit.Test; +import org.mockito.Mockito; +import org.springframework.web.socket.AbstractHttpRequestTests; +import org.springframework.web.socket.WebSocketHandler; + +import static org.junit.Assert.*; + + +/** + * Test fixture for {@link HttpSessionHandshakeInterceptor}. + * + * @author Rossen Stoyanchev + */ +public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTests { + + + @Test + public void copyAllAttributes() throws Exception { + + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + + this.servletRequest.getSession().setAttribute("foo", "bar"); + this.servletRequest.getSession().setAttribute("bar", "baz"); + + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + interceptor.beforeHandshake(request, response, wsHandler, attributes); + + assertEquals(2, attributes.size()); + assertEquals("bar", attributes.get("foo")); + assertEquals("baz", attributes.get("bar")); + } + + @Test + public void copySelectedAttributes() throws Exception { + + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + + this.servletRequest.getSession().setAttribute("foo", "bar"); + this.servletRequest.getSession().setAttribute("bar", "baz"); + + Set names = Collections.singleton("foo"); + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(names); + interceptor.beforeHandshake(request, response, wsHandler, attributes); + + assertEquals(1, attributes.size()); + assertEquals("bar", attributes.get("foo")); + } + + @Test + public void doNotCauseSessionCreation() throws Exception { + + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + interceptor.beforeHandshake(request, response, wsHandler, attributes); + + assertNull(this.servletRequest.getSession(false)); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java index 85c3540e393..ba1229a92b4 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -70,10 +71,11 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { MockitoAnnotations.initMocks(this); - this.session = new TestSockJsSession(sessionId, new StubSockJsServiceConfig(), this.wsHandler); + Map attributes = Collections.emptyMap(); + this.session = new TestSockJsSession(sessionId, new StubSockJsServiceConfig(), this.wsHandler, attributes); when(this.xhrHandler.getTransportType()).thenReturn(TransportType.XHR); - when(this.xhrHandler.createSession(sessionId, this.wsHandler)).thenReturn(this.session); + when(this.xhrHandler.createSession(sessionId, this.wsHandler, attributes)).thenReturn(this.session); when(this.xhrSendHandler.getTransportType()).thenReturn(TransportType.XHR_SEND); this.service = new DefaultSockJsService(this.taskScheduler, diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java index 0f3f54ce2e7..ef8396abd10 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java @@ -107,7 +107,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest this.servletRequest.setContent("[\"x\"]".getBytes("UTF-8")); WebSocketHandler wsHandler = mock(WebSocketHandler.class); - TestHttpSockJsSession session = new TestHttpSockJsSession("1", sockJsConfig, wsHandler); + TestHttpSockJsSession session = new TestHttpSockJsSession("1", sockJsConfig, wsHandler, null); session.delegateConnectionEstablished(); doThrow(new Exception()).when(wsHandler).handleMessage(session, new TextMessage("x")); @@ -127,7 +127,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest private void handleRequest(AbstractHttpReceivingTransportHandler transportHandler) throws Exception { WebSocketHandler wsHandler = mock(WebSocketHandler.class); - AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler); + AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler, null); transportHandler.setSockJsServiceConfiguration(new StubSockJsServiceConfig()); transportHandler.handleRequest(this.request, this.response, wsHandler, session); @@ -141,7 +141,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest resetResponse(); WebSocketHandler wsHandler = mock(WebSocketHandler.class); - AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler); + AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler, null); new XhrReceivingTransportHandler().handleRequest(this.request, this.response, wsHandler, session); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java index b58937b540a..7e7c9f1b18d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java @@ -66,7 +66,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests XhrPollingTransportHandler transportHandler = new XhrPollingTransportHandler(); transportHandler.setSockJsServiceConfiguration(this.sockJsConfig); - AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler); + AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session); assertEquals("application/javascript;charset=UTF-8", this.response.getHeaders().getContentType().toString()); @@ -92,7 +92,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests JsonpPollingTransportHandler transportHandler = new JsonpPollingTransportHandler(); transportHandler.setSockJsServiceConfiguration(this.sockJsConfig); - PollingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler); + PollingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session); @@ -114,7 +114,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests XhrStreamingTransportHandler transportHandler = new XhrStreamingTransportHandler(); transportHandler.setSockJsServiceConfiguration(this.sockJsConfig); - AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler); + AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session); @@ -128,7 +128,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests HtmlFileTransportHandler transportHandler = new HtmlFileTransportHandler(); transportHandler.setSockJsServiceConfiguration(this.sockJsConfig); - StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler); + StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session); @@ -150,7 +150,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests EventSourceTransportHandler transportHandler = new EventSourceTransportHandler(); transportHandler.setSockJsServiceConfiguration(this.sockJsConfig); - StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler); + StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java index 74abb4fb31b..ed8c9b28a4d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; +import java.util.Map; import org.junit.Before; import org.junit.Test; @@ -71,7 +72,7 @@ public class AbstractHttpSockJsSessionTests extends BaseAbstractSockJsSessionTes @Override protected TestAbstractHttpSockJsSession initSockJsSession() { - return new TestAbstractHttpSockJsSession(this.sockJsConfig, this.webSocketHandler); + return new TestAbstractHttpSockJsSession(this.sockJsConfig, this.webSocketHandler, null); } @Test @@ -126,8 +127,10 @@ public class AbstractHttpSockJsSessionTests extends BaseAbstractSockJsSessionTes private boolean heartbeatScheduled; - public TestAbstractHttpSockJsSession(SockJsServiceConfig config, WebSocketHandler handler) { - super("1", config, handler); + public TestAbstractHttpSockJsSession(SockJsServiceConfig config, WebSocketHandler handler, + Map attributes) { + + super("1", config, handler, attributes); } public boolean wasCacheFlushed() { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSessionTests.java index d993960d39c..f4a76713bb9 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSessionTests.java @@ -45,7 +45,8 @@ public class AbstractSockJsSessionTests extends BaseAbstractSockJsSessionTestsemptyMap()); } @Test @@ -102,7 +103,8 @@ public class AbstractSockJsSessionTests extends BaseAbstractSockJsSessionTestsemptyMap()); String msg1 = "message 1"; String msg2 = "message 2"; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestHttpSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestHttpSockJsSession.java index ac4fefeb294..6269883e2b1 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestHttpSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestHttpSockJsSession.java @@ -19,6 +19,7 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketHandler; @@ -45,8 +46,10 @@ public class TestHttpSockJsSession extends AbstractHttpSockJsSession { private String subProtocol; - public TestHttpSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + public TestHttpSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java index 4aad2c63363..1cee547aab2 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java @@ -18,9 +18,11 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.URI; import java.security.Principal; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; @@ -32,6 +34,8 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; */ public class TestSockJsSession extends AbstractSockJsSession { + private URI uri; + private HttpHeaders headers; private Principal principal; @@ -55,11 +59,22 @@ public class TestSockJsSession extends AbstractSockJsSession { private String subProtocol; - public TestSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + public TestSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } + public void setUri(URI uri) { + this.uri = uri; + } + + @Override + public URI getUri() { + return this.uri; + } + @Override public HttpHeaders getHandshakeHeaders() { return this.headers; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java index 003f2503c0b..f6ff6e50435 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Test; @@ -53,7 +54,8 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession @Override protected TestWebSocketServerSockJsSession initSockJsSession() { - return new TestWebSocketServerSockJsSession(this.sockJsConfig, this.webSocketHandler); + return new TestWebSocketServerSockJsSession(this.sockJsConfig, this.webSocketHandler, + Collections.emptyMap()); } @Test @@ -132,8 +134,10 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession private final List heartbeatSchedulingEvents = new ArrayList<>(); - public TestWebSocketServerSockJsSession(SockJsServiceConfig config, WebSocketHandler handler) { - super("1", config, handler); + public TestWebSocketServerSockJsSession(SockJsServiceConfig config, WebSocketHandler handler, + Map attributes) { + + super("1", config, handler, attributes); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java index 53d78aa1ea8..6f16888e688 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java @@ -21,7 +21,9 @@ import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; @@ -39,6 +41,8 @@ public class TestWebSocketSession implements WebSocketSession { private URI uri; + private Map attributes = new HashMap(); + private Principal principal; private InetSocketAddress localAddress; @@ -106,6 +110,21 @@ public class TestWebSocketSession implements WebSocketSession { this.headers = headers; } + /** + * @param attributes the attributes to set + */ + public void setHandshakeAttributes(Map attributes) { + this.attributes = attributes; + } + + /** + * @return the attributes + */ + @Override + public Map getHandshakeAttributes() { + return this.attributes; + } + /** * @return the principal */