diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java index 854d22a59ff..0a4b0acee8b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.lang.reflect.Method; import java.util.Arrays; import java.util.Collections; import java.util.Map; @@ -26,13 +27,17 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.Endpoint; +import javax.websocket.server.ServerEndpointConfig; +import org.apache.tomcat.websocket.server.WsHandshakeRequest; +import org.apache.tomcat.websocket.server.WsHttpUpgradeHandler; import org.apache.tomcat.websocket.server.WsServerContainer; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration; @@ -60,6 +65,18 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg Assert.isTrue(response instanceof ServletServerHttpResponse); HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse(); + if (hasDoUpgrade) { + doUpgrade(servletRequest, servletResponse, acceptedProtocol, endpoint); + } + else { + upgradeTomcat80RC1(servletRequest, acceptedProtocol, endpoint); + } + } + + private void doUpgrade(HttpServletRequest servletRequest, HttpServletResponse servletResponse, + String acceptedProtocol, Endpoint endpoint) { + + StringBuffer requestUrl = servletRequest.getRequestURL(); String path = servletRequest.getRequestURI(); // shouldn't matter Map pathParams = Collections. emptyMap(); @@ -71,11 +88,11 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg } catch (ServletException ex) { throw new HandshakeFailureException( - "Servlet request failed to upgrade to WebSocket, uri=" + request.getURI(), ex); + "Servlet request failed to upgrade to WebSocket, uri=" + requestUrl, ex); } catch (IOException ex) { throw new HandshakeFailureException( - "Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex); + "Response update failed during upgrade to WebSocket, uri=" + requestUrl, ex); } } @@ -85,4 +102,36 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg return (WsServerContainer) servletContext.getAttribute(attribute); } -} + // FIXME: Remove this after RC2 is out + + private void upgradeTomcat80RC1(HttpServletRequest request, String protocol, Endpoint endpoint) { + + WsHttpUpgradeHandler upgradeHandler; + try { + upgradeHandler = request.upgrade(WsHttpUpgradeHandler.class); + } + catch (Exception e) { + throw new HandshakeFailureException("Unable to create UpgardeHandler", e); + } + + WsHandshakeRequest webSocketRequest = new WsHandshakeRequest(request); + try { + Method method = ReflectionUtils.findMethod(WsHandshakeRequest.class, "finished"); + ReflectionUtils.makeAccessible(method); + method.invoke(webSocketRequest); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to upgrade HttpServletRequest", ex); + } + + ServerEndpointConfig endpointConfig = new ServerEndpointRegistration("/shouldntmatter", endpoint); + + upgradeHandler.preInit(endpoint, endpointConfig, getContainer(request), webSocketRequest, + protocol, Collections. emptyMap(), request.isSecure()); + } + + private static boolean hasDoUpgrade = (ReflectionUtils.findMethod(WsServerContainer.class, + "doUpgrade", HttpServletRequest.class, HttpServletResponse.class, + ServerEndpointConfig.class, Map.class) != null); + +} \ No newline at end of file