diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java index 82035d46082..dbb4c530287 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java @@ -213,7 +213,12 @@ public abstract class AbstractListenerWebSocketSession extends WebSocketSessi return this.isReady && this.currentData != null; } - public void setReady(boolean ready) { + /** + * Sub-classes can invoke this before sending a message (false) and + * after receiving the async send callback (true) effective translating + * async completion callback into simple flow control. + */ + public void setReadyToSend(boolean ready) { this.isReady = ready; } } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java index 815bd2da79e..9314dfe85d7 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java @@ -30,6 +30,8 @@ import org.eclipse.jetty.websocket.api.extensions.Frame; import org.eclipse.jetty.websocket.common.OpCode; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; + +import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.util.Assert; @@ -50,84 +52,90 @@ public class JettyWebSocketHandlerAdapter { private static final ByteBuffer EMPTY_PAYLOAD = ByteBuffer.wrap(new byte[0]); + private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false); - private final WebSocketHandler handler; + private final WebSocketHandler delegate; + + private JettyWebSocketSession session; - private JettyWebSocketSession wsSession; - public JettyWebSocketHandlerAdapter(WebSocketHandler handler) { - Assert.notNull("'handler' is required"); - this.handler = handler; + public JettyWebSocketHandlerAdapter(WebSocketHandler delegate) { + Assert.notNull("WebSocketHandler is required"); + this.delegate = delegate; } + @OnWebSocketConnect public void onWebSocketConnect(Session session) { - this.wsSession = new JettyWebSocketSession(session); + this.session = new JettyWebSocketSession(session); - HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); - this.handler.handle(this.wsSession).subscribe(resultSubscriber); + HandlerResultSubscriber subscriber = new HandlerResultSubscriber(); + this.delegate.handle(this.session).subscribe(subscriber); } @OnWebSocketMessage public void onWebSocketText(String message) { - if (this.wsSession != null) { - WebSocketMessage wsMessage = toMessage(Type.TEXT, message); - this.wsSession.handleMessage(wsMessage.getType(), wsMessage); + if (this.session != null) { + WebSocketMessage webSocketMessage = toMessage(Type.TEXT, message); + this.session.handleMessage(webSocketMessage.getType(), webSocketMessage); } } @OnWebSocketMessage public void onWebSocketBinary(byte[] message, int offset, int length) { - if (this.wsSession != null) { - WebSocketMessage wsMessage = toMessage(Type.BINARY, ByteBuffer.wrap(message, offset, length)); - wsSession.handleMessage(wsMessage.getType(), wsMessage); + if (this.session != null) { + ByteBuffer buffer = ByteBuffer.wrap(message, offset, length); + WebSocketMessage webSocketMessage = toMessage(Type.BINARY, buffer); + session.handleMessage(webSocketMessage.getType(), webSocketMessage); } } @OnWebSocketFrame public void onWebSocketFrame(Frame frame) { - if (this.wsSession != null) { + if (this.session != null) { if (OpCode.PONG == frame.getOpCode()) { - ByteBuffer message = frame.getPayload() != null ? frame.getPayload() : EMPTY_PAYLOAD; - WebSocketMessage wsMessage = toMessage(Type.PONG, message); - wsSession.handleMessage(wsMessage.getType(), wsMessage); + ByteBuffer buffer = (frame.getPayload() != null ? frame.getPayload() : EMPTY_PAYLOAD); + WebSocketMessage webSocketMessage = toMessage(Type.PONG, buffer); + session.handleMessage(webSocketMessage.getType(), webSocketMessage); } } } @OnWebSocketClose public void onWebSocketClose(int statusCode, String reason) { - if (this.wsSession != null) { - this.wsSession.handleClose(new CloseStatus(statusCode, reason)); + if (this.session != null) { + this.session.handleClose(new CloseStatus(statusCode, reason)); } } @OnWebSocketError public void onWebSocketError(Throwable cause) { - if (this.wsSession != null) { - this.wsSession.handleError(cause); + if (this.session != null) { + this.session.handleError(cause); } } private WebSocketMessage toMessage(Type type, T message) { if (Type.TEXT.equals(type)) { - return WebSocketMessage.create(Type.TEXT, - bufferFactory.wrap(((String) message).getBytes(StandardCharsets.UTF_8))); + byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = this.bufferFactory.wrap(bytes); + return WebSocketMessage.create(Type.TEXT, buffer); } else if (Type.BINARY.equals(type)) { - return WebSocketMessage.create(Type.BINARY, - bufferFactory.wrap((ByteBuffer) message)); + DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message); + return WebSocketMessage.create(Type.BINARY, buffer); } else if (Type.PONG.equals(type)) { - return WebSocketMessage.create(Type.PONG, - bufferFactory.wrap((ByteBuffer) message)); + DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message); + return WebSocketMessage.create(Type.PONG, buffer); } else { throw new IllegalArgumentException("Unexpected message type: " + message); } } + private final class HandlerResultSubscriber implements Subscriber { @Override @@ -142,15 +150,15 @@ public class JettyWebSocketHandlerAdapter { @Override public void onError(Throwable ex) { - if (wsSession != null) { - wsSession.close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage())); + if (session != null) { + session.close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage())); } } @Override public void onComplete() { - if (wsSession != null) { - wsSession.close(); + if (session != null) { + session.close(); } } } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketSession.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketSession.java index 56f04611c4f..7f23241f534 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketSession.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketSession.java @@ -17,6 +17,7 @@ package org.springframework.web.reactive.socket.adapter; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import org.eclipse.jetty.websocket.api.Session; @@ -52,22 +53,21 @@ public class JettyWebSocketSession extends AbstractListenerWebSocketSession() { - - @Override - public void onMessage(String message) { - WebSocketMessage wsMessage = toMessage(message); - wsSession.handleMessage(wsMessage.getType(), wsMessage); - } + this.session = new TomcatWebSocketSession(session); + session.addMessageHandler(String.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + this.session.handleMessage(webSocketMessage.getType(), webSocketMessage); }); - session.addMessageHandler(new MessageHandler.Whole() { - - @Override - public void onMessage(ByteBuffer message) { - WebSocketMessage wsMessage = toMessage(message); - wsSession.handleMessage(wsMessage.getType(), wsMessage); - } - + session.addMessageHandler(ByteBuffer.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + this.session.handleMessage(webSocketMessage.getType(), webSocketMessage); }); - session.addMessageHandler(new MessageHandler.Whole() { - - @Override - public void onMessage(PongMessage message) { - WebSocketMessage wsMessage = toMessage(message); - wsSession.handleMessage(wsMessage.getType(), wsMessage); - } - + session.addMessageHandler(PongMessage.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + this.session.handleMessage(webSocketMessage.getType(), webSocketMessage); }); HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); - this.handler.handle(this.wsSession).subscribe(resultSubscriber); + this.delegate.handle(this.session).subscribe(resultSubscriber); } @Override public void onClose(Session session, CloseReason reason) { - if (this.wsSession != null) { - this.wsSession.handleClose( - new CloseStatus(reason.getCloseCode().getCode(), reason.getReasonPhrase())); + if (this.session != null) { + int code = reason.getCloseCode().getCode(); + this.session.handleClose(new CloseStatus(code, reason.getReasonPhrase())); } } @Override public void onError(Session session, Throwable exception) { - if (this.wsSession != null) { - this.wsSession.handleError(exception); + if (this.session != null) { + this.session.handleError(exception); } } private WebSocketMessage toMessage(T message) { if (message instanceof String) { - return WebSocketMessage.create(Type.TEXT, - bufferFactory.wrap(((String) message).getBytes(StandardCharsets.UTF_8))); + byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); + return WebSocketMessage.create(Type.TEXT, this.bufferFactory.wrap(bytes)); } else if (message instanceof ByteBuffer) { - return WebSocketMessage.create(Type.BINARY, - bufferFactory.wrap((ByteBuffer) message)); + DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message); + return WebSocketMessage.create(Type.BINARY, buffer); } else if (message instanceof PongMessage) { - return WebSocketMessage.create(Type.PONG, - bufferFactory.wrap(((PongMessage) message).getApplicationData())); + DataBuffer buffer = this.bufferFactory.wrap(((PongMessage) message).getApplicationData()); + return WebSocketMessage.create(Type.PONG, buffer); } else { throw new IllegalArgumentException("Unexpected message type: " + message); } } + private final class HandlerResultSubscriber implements Subscriber { @Override @@ -139,15 +127,15 @@ public class TomcatWebSocketHandlerAdapter extends Endpoint { @Override public void onError(Throwable ex) { - if (wsSession != null) { - wsSession.close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage())); + if (session != null) { + session.close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage())); } } @Override public void onComplete() { - if (wsSession != null) { - wsSession.close(); + if (session != null) { + session.close(); } } } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketSession.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketSession.java index 6c11938ed28..a9698f6962a 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketSession.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketSession.java @@ -17,6 +17,7 @@ package org.springframework.web.reactive.socket.adapter; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; @@ -59,22 +60,21 @@ public class TomcatWebSocketSession extends AbstractListenerWebSocketSession WebSocketMessage toMessage(Type type, T message) { if (Type.TEXT.equals(type)) { - return WebSocketMessage.create(Type.TEXT, - bufferFactory.wrap(((String) message).getBytes(StandardCharsets.UTF_8))); + byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); + return WebSocketMessage.create(Type.TEXT, bufferFactory.wrap(bytes)); } else if (Type.BINARY.equals(type)) { - return WebSocketMessage.create(Type.BINARY, - bufferFactory.allocateBuffer().write((ByteBuffer[]) message)); + DataBuffer buffer = bufferFactory.allocateBuffer().write((ByteBuffer[]) message); + return WebSocketMessage.create(Type.BINARY, buffer); } else if (Type.PONG.equals(type)) { - return WebSocketMessage.create(Type.PONG, - bufferFactory.allocateBuffer().write((ByteBuffer[]) message)); + DataBuffer buffer = bufferFactory.allocateBuffer().write((ByteBuffer[]) message); + return WebSocketMessage.create(Type.PONG, buffer); } else { throw new IllegalArgumentException("Unexpected message type: " + message); @@ -144,12 +145,12 @@ public class UndertowWebSocketHandlerAdapter implements WebSocketConnectionCallb @Override public void onError(Throwable ex) { - wsSession.close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage())); + session.close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage())); } @Override public void onComplete() { - wsSession.close(); + session.close(); } } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java index eaee5879df8..c2a3a15b065 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java @@ -19,6 +19,7 @@ package org.springframework.web.reactive.socket.adapter; import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import io.undertow.websockets.core.CloseMessage; @@ -41,10 +42,12 @@ import org.springframework.web.reactive.socket.WebSocketSession; */ public class UndertowWebSocketSession extends AbstractListenerWebSocketSession { + public UndertowWebSocketSession(WebSocketChannel channel) throws URISyntaxException { super(channel, ObjectUtils.getIdentityHexString(channel), new URI(channel.getUrl())); } + @Override protected Mono closeInternal(CloseStatus status) { CloseMessage cm = new CloseMessage(status.getCode(), status.getReason()); @@ -69,26 +72,23 @@ public class UndertowWebSocketSession extends AbstractListenerWebSocketSession { + + private final class SendProcessorCallback implements WebSocketCallback { @Override public void complete(WebSocketChannel channel, Void context) { - getSendProcessor().setReady(true); + getSendProcessor().setReadyToSend(true); getSendProcessor().onWritePossible(); } @Override - public void onError(WebSocketChannel channel, Void context, - Throwable throwable) { + public void onError(WebSocketChannel channel, Void context, Throwable throwable) { getSendProcessor().cancel(); getSendProcessor().onError(throwable); } - } }