@ -28,9 +28,8 @@ import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription ;
import org.reactivestreams.Subscription ;
import org.springframework.core.io.buffer.DataBuffer ;
import org.springframework.core.io.buffer.DataBuffer ;
import org.springframework.core.io.buffer.DataBufferFactory ;
import org.springframework.http.server.reactive.ServerHttpRequest ;
import org.springframework.core.io.buffer.DefaultDataBufferFactory ;
import org.springframework.http.server.reactive.ServerHttpResponse ;
import org.springframework.util.Assert ;
import org.springframework.web.reactive.socket.CloseStatus ;
import org.springframework.web.reactive.socket.CloseStatus ;
import org.springframework.web.reactive.socket.WebSocketHandler ;
import org.springframework.web.reactive.socket.WebSocketHandler ;
import org.springframework.web.reactive.socket.WebSocketMessage ;
import org.springframework.web.reactive.socket.WebSocketMessage ;
@ -43,76 +42,84 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type;
* @author Violeta Georgieva
* @author Violeta Georgieva
* @since 5 . 0
* @since 5 . 0
* /
* /
public class TomcatWebSocketHandlerAdapter extends Endpoint {
public class TomcatWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport {
private final WebSocketHandler delegate ;
private TomcatWebSocketSession session ;
private TomcatWebSocketSession session ;
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory ( false ) ;
public TomcatWebSocketHandlerAdapter ( ServerHttpRequest request , ServerHttpResponse response ,
WebSocketHandler delegate ) {
public TomcatWebSocketHandlerAdapter ( WebSocketHandler delegate ) {
super ( request , response , delegate ) ;
Assert . notNull ( "WebSocketHandler is required" ) ;
this . delegate = delegate ;
}
}
@Override
public Endpoint getEndpoint ( ) {
public void onOpen ( Session session , EndpointConfig config ) {
return new StandardEndpoint ( ) ;
this . session = new TomcatWebSocketSession ( session ) ;
session . addMessageHandler ( String . class , message - > {
WebSocketMessage webSocketMessage = toMessage ( message ) ;
this . session . handleMessage ( webSocketMessage . getType ( ) , webSocketMessage ) ;
} ) ;
session . addMessageHandler ( ByteBuffer . class , message - > {
WebSocketMessage webSocketMessage = toMessage ( message ) ;
this . session . handleMessage ( webSocketMessage . getType ( ) , webSocketMessage ) ;
} ) ;
session . addMessageHandler ( PongMessage . class , message - > {
WebSocketMessage webSocketMessage = toMessage ( message ) ;
this . session . handleMessage ( webSocketMessage . getType ( ) , webSocketMessage ) ;
} ) ;
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber ( ) ;
this . delegate . handle ( this . session ) . subscribe ( resultSubscriber ) ;
}
}
private < T > WebSocketMessage toMessage ( T message ) {
private TomcatWebSocketSession getSession ( ) {
if ( message instanceof String ) {
return this . session ;
byte [ ] bytes = ( ( String ) message ) . getBytes ( StandardCharsets . UTF_8 ) ;
}
return WebSocketMessage . create ( Type . TEXT , this . bufferFactory . wrap ( bytes ) ) ;
}
else if ( message instanceof ByteBuffer ) {
private class StandardEndpoint extends Endpoint {
DataBuffer buffer = this . bufferFactory . wrap ( ( ByteBuffer ) message ) ;
return WebSocketMessage . create ( Type . BINARY , buffer ) ;
@Override
}
public void onOpen ( Session session , EndpointConfig config ) {
else if ( message instanceof PongMessage ) {
TomcatWebSocketHandlerAdapter . this . session = new TomcatWebSocketSession ( session ) ;
DataBuffer buffer = this . bufferFactory . wrap ( ( ( PongMessage ) message ) . getApplicationData ( ) ) ;
return WebSocketMessage . create ( Type . PONG , buffer ) ;
session . addMessageHandler ( String . class , message - > {
WebSocketMessage webSocketMessage = toMessage ( message ) ;
getSession ( ) . handleMessage ( webSocketMessage . getType ( ) , webSocketMessage ) ;
} ) ;
session . addMessageHandler ( ByteBuffer . class , message - > {
WebSocketMessage webSocketMessage = toMessage ( message ) ;
getSession ( ) . handleMessage ( webSocketMessage . getType ( ) , webSocketMessage ) ;
} ) ;
session . addMessageHandler ( PongMessage . class , message - > {
WebSocketMessage webSocketMessage = toMessage ( message ) ;
getSession ( ) . handleMessage ( webSocketMessage . getType ( ) , webSocketMessage ) ;
} ) ;
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber ( ) ;
getDelegate ( ) . handle ( TomcatWebSocketHandlerAdapter . this . session ) . subscribe ( resultSubscriber ) ;
}
}
else {
throw new IllegalArgumentException ( "Unexpected message type: " + message ) ;
private < T > WebSocketMessage toMessage ( T message ) {
if ( message instanceof String ) {
byte [ ] bytes = ( ( String ) message ) . getBytes ( StandardCharsets . UTF_8 ) ;
return WebSocketMessage . create ( Type . TEXT , getBufferFactory ( ) . wrap ( bytes ) ) ;
}
else if ( message instanceof ByteBuffer ) {
DataBuffer buffer = getBufferFactory ( ) . wrap ( ( ByteBuffer ) message ) ;
return WebSocketMessage . create ( Type . BINARY , buffer ) ;
}
else if ( message instanceof PongMessage ) {
DataBuffer buffer = getBufferFactory ( ) . wrap ( ( ( PongMessage ) message ) . getApplicationData ( ) ) ;
return WebSocketMessage . create ( Type . PONG , buffer ) ;
}
else {
throw new IllegalArgumentException ( "Unexpected message type: " + message ) ;
}
}
}
}
@Override
@Override
public void onClose ( Session session , CloseReason reason ) {
public void onClose ( Session session , CloseReason reason ) {
if ( this . session ! = null ) {
if ( getSession ( ) ! = null ) {
int code = reason . getCloseCode ( ) . getCode ( ) ;
int code = reason . getCloseCode ( ) . getCode ( ) ;
this . session . handleClose ( new CloseStatus ( code , reason . getReasonPhrase ( ) ) ) ;
getSession ( ) . handleClose ( new CloseStatus ( code , reason . getReasonPhrase ( ) ) ) ;
}
}
}
}
@Override
@Override
public void onError ( Session session , Throwable exception ) {
public void onError ( Session session , Throwable exception ) {
if ( this . session ! = null ) {
if ( getSession ( ) ! = null ) {
this . session . handleError ( exception ) ;
getSession ( ) . handleError ( exception ) ;
}
}
}
}
}
private final class HandlerResultSubscriber implements Subscriber < Void > {
private final class HandlerResultSubscriber implements Subscriber < Void > {
@Override
@Override
@ -127,15 +134,15 @@ public class TomcatWebSocketHandlerAdapter extends Endpoint {
@Override
@Override
public void onError ( Throwable ex ) {
public void onError ( Throwable ex ) {
if ( session ! = null ) {
if ( getSession ( ) ! = null ) {
session . close ( new CloseStatus ( CloseStatus . SERVER_ERROR . getCode ( ) , ex . getMessage ( ) ) ) ;
getSession ( ) . close ( new CloseStatus ( CloseStatus . SERVER_ERROR . getCode ( ) , ex . getMessage ( ) ) ) ;
}
}
}
}
@Override
@Override
public void onComplete ( ) {
public void onComplete ( ) {
if ( session ! = null ) {
if ( getSession ( ) ! = null ) {
session . close ( ) ;
getSession ( ) . close ( ) ;
}
}
}
}
}
}