@ -16,9 +16,12 @@
@@ -16,9 +16,12 @@
package org.springframework.web.socket.server.standard ;
import java.lang.reflect.Constructor ;
import java.util.Arrays ;
import java.util.Collections ;
import java.util.List ;
import java.util.Set ;
import java.util.concurrent.ConcurrentHashMap ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import javax.websocket.Decoder ;
@ -34,9 +37,6 @@ import io.undertow.servlet.websockets.ServletWebSocketHttpExchange;
@@ -34,9 +37,6 @@ import io.undertow.servlet.websockets.ServletWebSocketHttpExchange;
import io.undertow.websockets.core.WebSocketChannel ;
import io.undertow.websockets.core.WebSocketVersion ;
import io.undertow.websockets.core.protocol.Handshake ;
import io.undertow.websockets.core.protocol.version07.Hybi07Handshake ;
import io.undertow.websockets.core.protocol.version08.Hybi08Handshake ;
import io.undertow.websockets.core.protocol.version13.Hybi13Handshake ;
import io.undertow.websockets.jsr.ConfiguredServerEndpoint ;
import io.undertow.websockets.jsr.EncodingFactory ;
import io.undertow.websockets.jsr.EndpointSessionHandler ;
@ -45,6 +45,7 @@ import io.undertow.websockets.jsr.handshake.HandshakeUtil;
@@ -45,6 +45,7 @@ import io.undertow.websockets.jsr.handshake.HandshakeUtil;
import io.undertow.websockets.jsr.handshake.JsrHybi07Handshake ;
import io.undertow.websockets.jsr.handshake.JsrHybi08Handshake ;
import io.undertow.websockets.jsr.handshake.JsrHybi13Handshake ;
import org.springframework.util.ClassUtils ;
import org.xnio.StreamConnection ;
import org.springframework.http.server.ServerHttpRequest ;
@ -61,16 +62,47 @@ import org.springframework.web.socket.server.HandshakeFailureException;
@@ -61,16 +62,47 @@ import org.springframework.web.socket.server.HandshakeFailureException;
* /
public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
private final String [ ] supportedVersions = new String [ ] {
private static final Constructor < ServletWebSocketHttpExchange > exchangeConstructor ;
private static final boolean undertow10Present ;
static {
Class < ServletWebSocketHttpExchange > type = ServletWebSocketHttpExchange . class ;
Class < ? > [ ] paramTypes = new Class < ? > [ ] { HttpServletRequest . class , HttpServletResponse . class , Set . class } ;
if ( ClassUtils . hasConstructor ( type , paramTypes ) ) {
exchangeConstructor = ClassUtils . getConstructorIfAvailable ( type , paramTypes ) ;
undertow10Present = false ;
}
else {
paramTypes = new Class < ? > [ ] { HttpServletRequest . class , HttpServletResponse . class } ;
exchangeConstructor = ClassUtils . getConstructorIfAvailable ( type , paramTypes ) ;
undertow10Present = true ;
}
}
private static final String [ ] supportedVersions = new String [ ] {
WebSocketVersion . V13 . toHttpHeaderValue ( ) ,
WebSocketVersion . V08 . toHttpHeaderValue ( ) ,
WebSocketVersion . V07 . toHttpHeaderValue ( )
} ;
private Set < WebSocketChannel > peerConnections ;
public UndertowRequestUpgradeStrategy ( ) {
if ( undertow10Present ) {
this . peerConnections = null ;
}
else {
this . peerConnections = Collections . newSetFromMap ( new ConcurrentHashMap < WebSocketChannel , Boolean > ( ) ) ;
}
}
@Override
public String [ ] getSupportedVersions ( ) {
return this . supportedVersions ;
return supportedVersions ;
}
@Override
@ -80,7 +112,7 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat
@@ -80,7 +112,7 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat
HttpServletRequest servletRequest = getHttpServletRequest ( request ) ;
HttpServletResponse servletResponse = getHttpServletResponse ( response ) ;
final ServletWebSocketHttpExchange exchange = new ServletWebSocket HttpExchange( servletRequest , servletResponse ) ;
final ServletWebSocketHttpExchange exchange = create HttpExchange( servletRequest , servletResponse ) ;
exchange . putAttachment ( HandshakeUtil . PATH_PARAMS , Collections . < String , String > emptyMap ( ) ) ;
ServerWebSocketContainer wsContainer = ( ServerWebSocketContainer ) getContainer ( servletRequest ) ;
@ -95,6 +127,9 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat
@@ -95,6 +127,9 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat
@Override
public void handleUpgrade ( StreamConnection connection , HttpServerExchange serverExchange ) {
WebSocketChannel channel = handshake . createChannel ( exchange , connection , exchange . getBufferPool ( ) ) ;
if ( peerConnections ! = null ) {
peerConnections . add ( channel ) ;
}
endpointSessionHandler . onConnect ( exchange , channel ) ;
}
} ) ;
@ -102,6 +137,17 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat
@@ -102,6 +137,17 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat
handshake . handshake ( exchange ) ;
}
private ServletWebSocketHttpExchange createHttpExchange ( HttpServletRequest request , HttpServletResponse response ) {
try {
return ( this . peerConnections ! = null ?
exchangeConstructor . newInstance ( request , response , this . peerConnections ) :
exchangeConstructor . newInstance ( request , response ) ) ;
}
catch ( Exception ex ) {
throw new HandshakeFailureException ( "Failed to instantiate ServletWebSocketHttpExchange" , ex ) ;
}
}
private Handshake getHandshakeToUse ( ServletWebSocketHttpExchange exchange , ConfiguredServerEndpoint endpoint ) {
Handshake handshake = new JsrHybi13Handshake ( endpoint ) ;
if ( handshake . matches ( exchange ) ) {