@ -16,24 +16,21 @@
@@ -16,24 +16,21 @@
package org.springframework.web.socket.server.support ;
import java.io.IOException ;
import java.lang.reflect.Method ;
import java.util.Arrays ;
import java.util.Collections ;
import java.util.Map ;
import javax.servlet.ServletContext ;
import javax.servlet.ServletException ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import javax.websocket.Endpoint ;
import javax.websocket.server.ServerEndpointConfig ;
import org.apache.tomcat.websocket.server.WsHandshakeRequest ;
import org.apache.tomcat.websocket.server.WsHttpUpgradeHandler ;
import org.apache.tomcat.websocket.server.WsServerContainer ;
import org.springframework.http.server.ServerHttpRequest ;
import org.springframework.http.server.ServerHttpResponse ;
import org.springframework.http.server.ServletServerHttpRequest ;
import org.springframework.http.server.ServletServerHttpResponse ;
import org.springframework.util.Assert ;
import org.springframework.util.ReflectionUtils ;
import org.springframework.web.socket.server.HandshakeFailureException ;
import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration ;
@ -45,6 +42,7 @@ import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration
@@ -45,6 +42,7 @@ import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration
* /
public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
@Override
public String [ ] getSupportedVersions ( ) {
return new String [ ] { "13" } ;
@ -52,37 +50,32 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg
@@ -52,37 +50,32 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg
@Override
public void upgradeInternal ( ServerHttpRequest request , ServerHttpResponse response ,
String acceptedProtocol , Endpoint endpoint ) throws IO Exception {
String acceptedProtocol , Endpoint endpoint ) throws HandshakeFailure Exception {
Assert . isTrue ( request instanceof ServletServerHttpRequest ) ;
HttpServletRequest servletRequest = ( ( ServletServerHttpRequest ) request ) . getServletRequest ( ) ;
WsHttpUpgradeHandler upgradeHandler ;
try {
upgradeHandler = servletRequest . upgrade ( WsHttpUpgradeHandler . class ) ;
}
catch ( ServletException e ) {
throw new HandshakeFailureException ( "Unable to create UpgardeHandler" , e ) ;
}
Assert . isTrue ( response instanceof ServletServerHttpResponse ) ;
HttpServletResponse servletResponse = ( ( ServletServerHttpResponse ) response ) . getServletResponse ( ) ;
String path = servletRequest . getRequestURI ( ) ; // shouldn't matter
Map < String , String > pathParams = Collections . < String , String > emptyMap ( ) ;
ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration ( path , endpoint ) ;
endpointConfig . setSubprotocols ( Arrays . asList ( acceptedProtocol ) ) ;
WsHandshakeRequest webSocketRequest = new WsHandshakeRequest ( servletRequest ) ;
try {
Method method = ReflectionUtils . findMethod ( WsHandshakeRequest . class , "finished" ) ;
ReflectionUtils . makeAccessible ( method ) ;
method . invoke ( webSocketRequest ) ;
getContainer ( servletRequest ) . doUpgrade ( servletRequest , servletResponse , endpointConfig , pathParams ) ;
}
catch ( Exception ex ) {
throw new HandshakeFailureException ( "Failed to upgrade HttpServletRequest" , ex ) ;
}
}
private WsServerContainer getContainer ( HttpServletRequest servletRequest ) {
String attribute = "javax.websocket.server.ServerContainer" ;
ServletContext servletContext = servletRequest . getServletContext ( ) ;
WsServerContainer serverContainer = ( WsServerContainer ) servletContext . getAttribute ( attribute ) ;
ServerEndpointConfig endpointConfig = new ServerEndpointRegistration ( "/shouldntmatter" , endpoint ) ;
upgradeHandler . preInit ( endpoint , endpointConfig , serverContainer , webSocketRequest ,
acceptedProtocol , Collections . < String , String > emptyMap ( ) , servletRequest . isSecure ( ) ) ;
return ( WsServerContainer ) servletContext . getAttribute ( attribute ) ;
}
}