diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/HandshakeInfo.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/HandshakeInfo.java index 7cedb404ede..04013021f74 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/HandshakeInfo.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/HandshakeInfo.java @@ -39,26 +39,27 @@ public class HandshakeInfo { private final Mono principalMono; - private HttpHeaders headers; + private final HttpHeaders headers; - private Optional protocol; + private final Optional protocol; - public HandshakeInfo(URI uri, Mono principal) { - this(uri, new HttpHeaders(), principal, Optional.empty()); - } - - public HandshakeInfo(URI uri, HttpHeaders headers, Mono principal, - Optional subProtocol) { - + /** + * Constructor with information about the handshake. + * @param uri the endpoint URL + * @param headers request headers for server or response headers or client + * @param principal the principal for the session + * @param protocol the negotiated sub-protocol + */ + public HandshakeInfo(URI uri, HttpHeaders headers, Mono principal, Optional protocol) { Assert.notNull(uri, "URI is required."); Assert.notNull(headers, "HttpHeaders are required."); Assert.notNull(principal, "Principal is required."); - Assert.notNull(subProtocol, "Sub-protocol is required."); + Assert.notNull(protocol, "Sub-protocol is required."); this.uri = uri; this.headers = headers; this.principalMono = principal; - this.protocol = subProtocol; + this.protocol = protocol; } @@ -77,15 +78,6 @@ public class HandshakeInfo { return this.headers; } - /** - * Sets the handshake HTTP headers. Those are the request headers for a - * server session and the response headers for a client session. - * @param headers the handshake HTTP headers. - */ - public void setHeaders(HttpHeaders headers) { - this.headers = headers; - } - /** * Return the principal associated with the handshake HTTP request. */ @@ -102,14 +94,6 @@ public class HandshakeInfo { return this.protocol; } - /** - * Sets the sub-protocol negotiated at handshake time. - * @param protocol the sub-protocol negotiated at handshake time. - */ - public void setSubProtocol(Optional protocol) { - this.protocol = protocol; - } - @Override public String toString() { diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java index 44509fd9af6..29abdd3a9e2 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java @@ -18,6 +18,7 @@ package org.springframework.web.reactive.socket.adapter; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.function.Function; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose; @@ -33,12 +34,12 @@ import org.reactivestreams.Subscription; import reactor.core.publisher.MonoProcessor; import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.util.Assert; import org.springframework.web.reactive.socket.CloseStatus; -import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketMessage; import org.springframework.web.reactive.socket.WebSocketMessage.Type; +import org.springframework.web.reactive.socket.WebSocketSession; /** * Jetty {@link WebSocket @WebSocket} handler that delegates events to a @@ -49,34 +50,41 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type; * @since 5.0 */ @WebSocket -public class JettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport { +public class JettyWebSocketHandlerAdapter { private static final ByteBuffer EMPTY_PAYLOAD = ByteBuffer.wrap(new byte[0]); - private JettyWebSocketSession delegateSession; + + private final WebSocketHandler delegateHandler; private final MonoProcessor completionMono; + private final Function sessionFactory; + + private JettyWebSocketSession delegateSession; + - public JettyWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info, - DataBufferFactory bufferFactory) { + public JettyWebSocketHandlerAdapter(WebSocketHandler handler, + Function sessionFactory) { - this(delegate, info, bufferFactory, null); + this(handler, null, sessionFactory); } - public JettyWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info, - DataBufferFactory bufferFactory, MonoProcessor completionMono) { + public JettyWebSocketHandlerAdapter(WebSocketHandler handler, MonoProcessor completionMono, + Function sessionFactory) { - super(delegate, info, bufferFactory); + Assert.notNull("WebSocketHandler is required"); + Assert.notNull("'sessionFactory' is required"); + this.delegateHandler = handler; this.completionMono = completionMono; + this.sessionFactory = sessionFactory; } - @OnWebSocketConnect public void onWebSocketConnect(Session session) { - this.delegateSession = new JettyWebSocketSession(session, getHandshakeInfo(), bufferFactory()); + this.delegateSession = sessionFactory.apply(session); HandlerResultSubscriber subscriber = new HandlerResultSubscriber(); - getDelegate().handle(this.delegateSession).subscribe(subscriber); + this.delegateHandler.handle(this.delegateSession).subscribe(subscriber); } @OnWebSocketMessage @@ -108,17 +116,19 @@ public class JettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport } private WebSocketMessage toMessage(Type type, T message) { + WebSocketSession session = this.delegateSession; + Assert.state(session != null, "Cannot create message without a session"); if (Type.TEXT.equals(type)) { byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); - DataBuffer buffer = bufferFactory().wrap(bytes); + DataBuffer buffer = session.bufferFactory().wrap(bytes); return new WebSocketMessage(Type.TEXT, buffer); } else if (Type.BINARY.equals(type)) { - DataBuffer buffer = bufferFactory().wrap((ByteBuffer) message); + DataBuffer buffer = session.bufferFactory().wrap((ByteBuffer) message); return new WebSocketMessage(Type.BINARY, buffer); } else if (Type.PONG.equals(type)) { - DataBuffer buffer = bufferFactory().wrap((ByteBuffer) message); + DataBuffer buffer = session.bufferFactory().wrap((ByteBuffer) message); return new WebSocketMessage(Type.PONG, buffer); } else { diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardEndpoint.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardEndpoint.java deleted file mode 100644 index 94fc35597de..00000000000 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardEndpoint.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright 2002-2016 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.reactive.socket.adapter; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import javax.websocket.CloseReason; -import javax.websocket.Endpoint; -import javax.websocket.EndpointConfig; -import javax.websocket.PongMessage; -import javax.websocket.Session; - -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import reactor.core.publisher.MonoProcessor; - -import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferFactory; -import org.springframework.web.reactive.socket.CloseStatus; -import org.springframework.web.reactive.socket.HandshakeInfo; -import org.springframework.web.reactive.socket.WebSocketHandler; -import org.springframework.web.reactive.socket.WebSocketMessage; -import org.springframework.web.reactive.socket.WebSocketMessage.Type; - -/** - * {@link Endpoint} delegating events - * to a reactive {@link WebSocketHandler} and its session. - * - * @author Violeta Georgieva - * @since 5.0 - */ -public class StandardEndpoint extends Endpoint { - - private final WebSocketHandler handler; - - private final HandshakeInfo info; - - private final DataBufferFactory bufferFactory; - - private StandardWebSocketSession delegateSession; - - private final MonoProcessor completionMono; - - - public StandardEndpoint(WebSocketHandler handler, HandshakeInfo info, - DataBufferFactory bufferFactory) { - this(handler, info, bufferFactory, null); - } - - public StandardEndpoint(WebSocketHandler handler, HandshakeInfo info, - DataBufferFactory bufferFactory, MonoProcessor completionMono) { - this.handler = handler; - this.info = info; - this.bufferFactory = bufferFactory; - this.completionMono = completionMono; - } - - - @Override - public void onOpen(Session nativeSession, EndpointConfig config) { - - this.delegateSession = new StandardWebSocketSession(nativeSession, this.info, this.bufferFactory); - - nativeSession.addMessageHandler(String.class, message -> { - WebSocketMessage webSocketMessage = toMessage(message); - this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); - }); - nativeSession.addMessageHandler(ByteBuffer.class, message -> { - WebSocketMessage webSocketMessage = toMessage(message); - this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); - }); - nativeSession.addMessageHandler(PongMessage.class, message -> { - WebSocketMessage webSocketMessage = toMessage(message); - this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); - }); - - HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); - this.handler.handle(this.delegateSession).subscribe(resultSubscriber); - } - - private WebSocketMessage toMessage(T message) { - if (message instanceof String) { - byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); - return new WebSocketMessage(Type.TEXT, this.bufferFactory.wrap(bytes)); - } - else if (message instanceof ByteBuffer) { - DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message); - return new WebSocketMessage(Type.BINARY, buffer); - } - else if (message instanceof PongMessage) { - DataBuffer buffer = this.bufferFactory.wrap(((PongMessage) message).getApplicationData()); - return new WebSocketMessage(Type.PONG, buffer); - } - else { - throw new IllegalArgumentException("Unexpected message type: " + message); - } - } - - @Override - public void onClose(Session session, CloseReason reason) { - if (this.delegateSession != null) { - int code = reason.getCloseCode().getCode(); - this.delegateSession.handleClose(new CloseStatus(code, reason.getReasonPhrase())); - } - } - - @Override - public void onError(Session session, Throwable exception) { - if (this.delegateSession != null) { - this.delegateSession.handleError(exception); - } - } - - protected HandshakeInfo getHandshakeInfo() { - return this.info; - } - - - private final class HandlerResultSubscriber implements Subscriber { - - @Override - public void onSubscribe(Subscription subscription) { - subscription.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Void aVoid) { - // no op - } - - @Override - public void onError(Throwable ex) { - if (completionMono != null) { - completionMono.onError(ex); - } - if (delegateSession != null) { - int code = CloseStatus.SERVER_ERROR.getCode(); - delegateSession.close(new CloseStatus(code, ex.getMessage())); - } - } - - @Override - public void onComplete() { - if (completionMono != null) { - completionMono.onComplete(); - } - if (delegateSession != null) { - delegateSession.close(); - } - } - } - -} \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardWebSocketHandlerAdapter.java index 32d8ed30438..351fd88de66 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardWebSocketHandlerAdapter.java @@ -16,31 +16,153 @@ package org.springframework.web.reactive.socket.adapter; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.function.Function; +import javax.websocket.CloseReason; import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.PongMessage; +import javax.websocket.Session; -import org.springframework.core.io.buffer.DataBufferFactory; -import org.springframework.web.reactive.socket.HandshakeInfo; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.publisher.MonoProcessor; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.util.Assert; +import org.springframework.web.reactive.socket.CloseStatus; import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.WebSocketMessage; +import org.springframework.web.reactive.socket.WebSocketMessage.Type; +import org.springframework.web.reactive.socket.WebSocketSession; /** - * Adapter for Java WebSocket API (JSR-356). - * + * Adapter for Java WebSocket API (JSR-356) that delegates events to a reactive + * {@link WebSocketHandler} and its session. + * * @author Violeta Georgieva * @author Rossen Stoyanchev * @since 5.0 */ -public class StandardWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport { +public class StandardWebSocketHandlerAdapter extends Endpoint { + + private final WebSocketHandler delegateHandler; + + private final MonoProcessor completionMono; + private Function sessionFactory; - public StandardWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info, - DataBufferFactory bufferFactory) { + private StandardWebSocketSession delegateSession; - super(delegate, info, bufferFactory); + + public StandardWebSocketHandlerAdapter(WebSocketHandler handler, + Function sessionFactory) { + + this(handler, null, sessionFactory); } + public StandardWebSocketHandlerAdapter(WebSocketHandler handler, MonoProcessor completionMono, + Function sessionFactory) { + + Assert.notNull("WebSocketHandler is required"); + Assert.notNull("'sessionFactory' is required"); + this.delegateHandler = handler; + this.completionMono = completionMono; + this.sessionFactory = sessionFactory; + } + + + @Override + public void onOpen(Session session, EndpointConfig config) { + + this.delegateSession = this.sessionFactory.apply(session); + + session.addMessageHandler(String.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); + }); + session.addMessageHandler(ByteBuffer.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); + }); + session.addMessageHandler(PongMessage.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); + }); + + HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); + this.delegateHandler.handle(this.delegateSession).subscribe(resultSubscriber); + } + + private WebSocketMessage toMessage(T message) { + WebSocketSession session = this.delegateSession; + Assert.state(session != null, "Cannot create message without a session"); + if (message instanceof String) { + byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); + return new WebSocketMessage(Type.TEXT, session.bufferFactory().wrap(bytes)); + } + else if (message instanceof ByteBuffer) { + DataBuffer buffer = session.bufferFactory().wrap((ByteBuffer) message); + return new WebSocketMessage(Type.BINARY, buffer); + } + else if (message instanceof PongMessage) { + DataBuffer buffer = session.bufferFactory().wrap(((PongMessage) message).getApplicationData()); + return new WebSocketMessage(Type.PONG, buffer); + } + else { + throw new IllegalArgumentException("Unexpected message type: " + message); + } + } + + @Override + public void onClose(Session session, CloseReason reason) { + if (this.delegateSession != null) { + int code = reason.getCloseCode().getCode(); + this.delegateSession.handleClose(new CloseStatus(code, reason.getReasonPhrase())); + } + } + + @Override + public void onError(Session session, Throwable exception) { + if (this.delegateSession != null) { + this.delegateSession.handleError(exception); + } + } + + + private final class HandlerResultSubscriber implements Subscriber { + + @Override + public void onSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void aVoid) { + // no op + } + + @Override + public void onError(Throwable ex) { + if (completionMono != null) { + completionMono.onError(ex); + } + if (delegateSession != null) { + int code = CloseStatus.SERVER_ERROR.getCode(); + delegateSession.close(new CloseStatus(code, ex.getMessage())); + } + } - public Endpoint getEndpoint() { - return new StandardEndpoint(getDelegate(), getHandshakeInfo(), bufferFactory()); + @Override + public void onComplete() { + if (completionMono != null) { + completionMono.onComplete(); + } + if (delegateSession != null) { + delegateSession.close(); + } + } } -} +} \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java index 57c61edd01c..0b5322e223a 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java @@ -25,20 +25,17 @@ import io.undertow.websockets.core.BufferedBinaryMessage; import io.undertow.websockets.core.BufferedTextMessage; import io.undertow.websockets.core.CloseMessage; import io.undertow.websockets.core.WebSocketChannel; -import io.undertow.websockets.spi.WebSocketHttpExchange; - import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; - import reactor.core.publisher.MonoProcessor; import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.util.Assert; import org.springframework.web.reactive.socket.CloseStatus; -import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketMessage; import org.springframework.web.reactive.socket.WebSocketMessage.Type; +import org.springframework.web.reactive.socket.WebSocketSession; /** * Undertow {@link WebSocketConnectionCallback} implementation that adapts and @@ -48,36 +45,35 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type; * @author Rossen Stoyanchev * @since 5.0 */ -public class UndertowWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport - implements WebSocketConnectionCallback { +public class UndertowWebSocketHandlerAdapter { - private UndertowWebSocketSession delegateSession; + private final WebSocketHandler delegateHandler; private final MonoProcessor completionMono; + private UndertowWebSocketSession delegateSession; - public UndertowWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info, - DataBufferFactory bufferFactory) { - this(delegate, info, bufferFactory, null); + public UndertowWebSocketHandlerAdapter(WebSocketHandler handler) { + this(handler, null); } - public UndertowWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info, - DataBufferFactory bufferFactory, MonoProcessor completionMono) { + public UndertowWebSocketHandlerAdapter(WebSocketHandler handler, MonoProcessor completionMono) { - super(delegate, info, bufferFactory); + Assert.notNull("WebSocketHandler is required"); + Assert.notNull("'sessionFactory' is required"); + this.delegateHandler = handler; this.completionMono = completionMono; } - @Override - public void onConnect(WebSocketHttpExchange exchange, WebSocketChannel channel) { - this.delegateSession = new UndertowWebSocketSession(channel, getHandshakeInfo(), bufferFactory()); - channel.getReceiveSetter().set(new UndertowReceiveListener()); - channel.resumeReceives(); + public void handle(UndertowWebSocketSession webSocketSession) { + this.delegateSession = webSocketSession; + webSocketSession.getDelegate().getReceiveSetter().set(new UndertowReceiveListener()); + webSocketSession.getDelegate().resumeReceives(); HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); - getDelegate().handle(this.delegateSession).subscribe(resultSubscriber); + this.delegateHandler.handle(this.delegateSession).subscribe(resultSubscriber); } @@ -113,16 +109,18 @@ public class UndertowWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupp } private WebSocketMessage toMessage(Type type, T message) { + WebSocketSession session = delegateSession; + Assert.state(session != null, "Cannot create message without a session"); if (Type.TEXT.equals(type)) { byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); - return new WebSocketMessage(Type.TEXT, bufferFactory().wrap(bytes)); + return new WebSocketMessage(Type.TEXT, session.bufferFactory().wrap(bytes)); } else if (Type.BINARY.equals(type)) { - DataBuffer buffer = bufferFactory().allocateBuffer().write((ByteBuffer[]) message); + DataBuffer buffer = session.bufferFactory().allocateBuffer().write((ByteBuffer[]) message); return new WebSocketMessage(Type.BINARY, buffer); } else if (Type.PONG.equals(type)) { - DataBuffer buffer = bufferFactory().allocateBuffer().write((ByteBuffer[]) message); + DataBuffer buffer = session.bufferFactory().allocateBuffer().write((ByteBuffer[]) message); return new WebSocketMessage(Type.PONG, buffer); } else { diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java deleted file mode 100644 index e8062c8af42..00000000000 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2002-2016 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.web.reactive.socket.adapter; - -import org.springframework.core.io.buffer.DataBufferFactory; -import org.springframework.util.Assert; -import org.springframework.web.reactive.socket.HandshakeInfo; -import org.springframework.web.reactive.socket.WebSocketHandler; - -/** - * Base class for adapters from event-listener WebSocket APIs (e.g. Java - * WebSocket API JSR-356, Jetty, Undertow) to the Reactive Streams based - * {@link WebSocketHandler}. - * - * @author Rossen Stoyanchev - * @since 5.0 - */ -public abstract class WebSocketHandlerAdapterSupport { - - private final WebSocketHandler delegate; - - private final HandshakeInfo handshakeInfo; - - private final DataBufferFactory bufferFactory; - - - protected WebSocketHandlerAdapterSupport(WebSocketHandler delegate, HandshakeInfo info, - DataBufferFactory bufferFactory) { - - Assert.notNull(delegate, "WebSocketHandler delegate is required"); - Assert.notNull(info, "HandshakeInfo is required."); - Assert.notNull(bufferFactory, "DataBufferFactory is required"); - - this.delegate = delegate; - this.handshakeInfo = info; - this.bufferFactory = bufferFactory; - } - - - protected WebSocketHandler getDelegate() { - return this.delegate; - } - - protected HandshakeInfo getHandshakeInfo() { - return this.handshakeInfo; - } - - @SuppressWarnings("unchecked") - protected T bufferFactory() { - return (T) this.bufferFactory; - } - -} diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java index f96f64ef23b..6341206718d 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java @@ -17,13 +17,10 @@ package org.springframework.web.reactive.socket.client; import java.net.URI; -import java.util.Optional; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeResponse; -import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; - import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; @@ -31,15 +28,16 @@ import org.springframework.context.Lifecycle; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.HttpHeaders; -import org.springframework.util.ObjectUtils; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter; +import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession; /** * A Jetty based implementation of {@link WebSocketClient}. * * @author Violeta Georgieva + * @author Rossen Stoyanchev * @since 5.0 */ public class JettyWebSocketClient extends WebSocketClientSupport implements WebSocketClient, Lifecycle { @@ -76,31 +74,38 @@ public class JettyWebSocketClient extends WebSocketClientSupport implements WebS @Override public Mono execute(URI url, HttpHeaders headers, WebSocketHandler handler) { - return connectInternal(url, headers, handler); + return executeInternal(url, headers, handler); } - private Mono connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { - MonoProcessor processor = MonoProcessor.create(); + private Mono executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { + MonoProcessor completionMono = MonoProcessor.create(); return Mono.fromCallable( () -> { - HandshakeInfo info = new HandshakeInfo(url, Mono.empty()); - Object adapter = new JettyClientAdapter(handler, info, this.bufferFactory, processor); - ClientUpgradeRequest request = createRequest(url, headers, handler); - return this.wsClient.connect(adapter, url, request); + String[] protocols = beforeHandshake(url, headers, handler); + ClientUpgradeRequest upgradeRequest = createRequest(url, headers, protocols); + Object jettyHandler = createJettyHandler(url, handler, completionMono); + return this.wsClient.connect(jettyHandler, url, upgradeRequest); }) - .then(processor); + .then(completionMono); } - private ClientUpgradeRequest createRequest(URI url, HttpHeaders headers, WebSocketHandler handler) { - ClientUpgradeRequest request = new ClientUpgradeRequest(); - - String[] protocols = beforeHandshake(url, headers, handler); - if (!ObjectUtils.isEmpty(protocols)) { - request.setSubProtocols(protocols); - } + private Object createJettyHandler(URI url, WebSocketHandler handler, MonoProcessor completion) { + return new JettyWebSocketHandlerAdapter( + handler, completion, session -> createJettySession(url, session)); + } - headers.forEach((k, v) -> request.setHeader(k, v)); + private JettyWebSocketSession createJettySession(URI url, Session session) { + UpgradeResponse response = session.getUpgradeResponse(); + HttpHeaders responseHeaders = new HttpHeaders(); + response.getHeaders().forEach(responseHeaders::put); + HandshakeInfo info = afterHandshake(url, responseHeaders); + return new JettyWebSocketSession(session, info, bufferFactory); + } + private ClientUpgradeRequest createRequest(URI url, HttpHeaders headers, String[] protocols) { + ClientUpgradeRequest request = new ClientUpgradeRequest(); + request.setSubProtocols(protocols); + headers.forEach(request::setHeader); return request; } @@ -140,32 +145,4 @@ public class JettyWebSocketClient extends WebSocketClientSupport implements WebS } } - - @WebSocket - private static final class JettyClientAdapter extends JettyWebSocketHandlerAdapter { - - public JettyClientAdapter(WebSocketHandler delegate, - HandshakeInfo info, DataBufferFactory bufferFactory, MonoProcessor processor) { - super(delegate, info, bufferFactory, processor); - } - - @Override - public void onWebSocketConnect(Session session) { - UpgradeResponse response = session.getUpgradeResponse(); - - getHandshakeInfo().setHeaders(getResponseHeaders(response)); - getHandshakeInfo().setSubProtocol( - Optional.ofNullable(response.getAcceptedSubProtocol())); - - super.onWebSocketConnect(session); - } - - private HttpHeaders getResponseHeaders(UpgradeResponse response) { - HttpHeaders responseHeaders = new HttpHeaders(); - response.getHeaders().forEach((k, v) -> responseHeaders.put(k, v)); - return responseHeaders; - } - - } - } \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java index 8c28741af27..b08da13ff71 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java @@ -70,7 +70,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen }) .then(response -> { HttpHeaders responseHeaders = getResponseHeaders(response); - HandshakeInfo info = afterHandshake(url, response.status().code(), responseHeaders); + HandshakeInfo info = afterHandshake(url, responseHeaders); ByteBufAllocator allocator = response.channel().alloc(); NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator); diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java index 139bb9688cf..a55b8f94e87 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java @@ -114,7 +114,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We .flatMap(tuple -> { WebSocketResponse response = tuple.getT1(); HttpHeaders responseHeaders = getResponseHeaders(response); - HandshakeInfo info = afterHandshake(url, response.getStatus().code(), responseHeaders); + HandshakeInfo info = afterHandshake(url, responseHeaders); ByteBufAllocator allocator = response.unsafeNettyChannel().alloc(); NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator); diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java index 1a4c47a16c7..53be6829706 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java @@ -20,25 +20,25 @@ import java.net.URI; import java.util.Arrays; import java.util.List; import java.util.Map; -import java.util.Optional; import javax.websocket.ClientEndpointConfig; +import javax.websocket.ClientEndpointConfig.Configurator; import javax.websocket.ContainerProvider; import javax.websocket.Endpoint; -import javax.websocket.EndpointConfig; import javax.websocket.HandshakeResponse; import javax.websocket.Session; import javax.websocket.WebSocketContainer; -import javax.websocket.ClientEndpointConfig.Configurator; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; +import reactor.core.scheduler.Schedulers; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.HttpHeaders; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; -import org.springframework.web.reactive.socket.adapter.StandardEndpoint; +import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter; +import org.springframework.web.reactive.socket.adapter.StandardWebSocketSession; /** * A Java WebSocket API (JSR-356) based implementation of @@ -78,70 +78,59 @@ public class StandardWebSocketClient extends WebSocketClientSupport implements W @Override public Mono execute(URI url, HttpHeaders headers, WebSocketHandler handler) { - return connectInternal(url, headers, handler); + return executeInternal(url, headers, handler); } - private Mono connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { - MonoProcessor processor = MonoProcessor.create(); - return Mono.fromCallable(() -> { - StandardWebSocketClientConfigurator configurator = - new StandardWebSocketClientConfigurator(headers); - - ClientEndpointConfig endpointConfig = createClientEndpointConfig( - configurator, beforeHandshake(url, headers, handler)); - - HandshakeInfo info = new HandshakeInfo(url, Mono.empty()); - - Endpoint endpoint = new StandardClientEndpoint(handler, info, - this.bufferFactory, configurator, processor); - - Session session = this.wsContainer.connectToServer(endpoint, endpointConfig, url); - return session; - }).then(processor); + private Mono executeInternal(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) { + MonoProcessor completionMono = MonoProcessor.create(); + return Mono.fromCallable( + () -> { + String[] subProtocols = beforeHandshake(url, requestHeaders, handler); + DefaultConfigurator configurator = new DefaultConfigurator(requestHeaders); + ClientEndpointConfig config = createEndpointConfig(configurator, subProtocols); + Endpoint endpoint = createEndpoint(url, handler, completionMono, configurator); + return this.wsContainer.connectToServer(endpoint, config, url); + }) + .subscribeOn(Schedulers.elastic()) // connectToServer is blocking + .then(completionMono); } - private ClientEndpointConfig createClientEndpointConfig( - StandardWebSocketClientConfigurator configurator, String[] subProtocols) { - + private ClientEndpointConfig createEndpointConfig(Configurator configurator, String[] subProtocols) { return ClientEndpointConfig.Builder.create() .configurator(configurator) .preferredSubprotocols(Arrays.asList(subProtocols)) .build(); } + private StandardWebSocketHandlerAdapter createEndpoint(URI url, WebSocketHandler handler, + MonoProcessor completion, DefaultConfigurator configurator) { - private static final class StandardClientEndpoint extends StandardEndpoint { - - private final StandardWebSocketClientConfigurator configurator; - - public StandardClientEndpoint(WebSocketHandler handler, HandshakeInfo info, - DataBufferFactory bufferFactory, StandardWebSocketClientConfigurator configurator, - MonoProcessor processor) { - super(handler, info, bufferFactory, processor); - this.configurator = configurator; - } - - @Override - public void onOpen(Session nativeSession, EndpointConfig config) { - getHandshakeInfo().setHeaders(this.configurator.getResponseHeaders()); - getHandshakeInfo().setSubProtocol( - Optional.ofNullable(nativeSession.getNegotiatedSubprotocol())); + return new StandardWebSocketHandlerAdapter(handler, completion, + session -> createSession(url, configurator.getResponseHeaders(), session)); + } - super.onOpen(nativeSession, config); - } + private StandardWebSocketSession createSession(URI url, HttpHeaders responseHeaders, Session session) { + HandshakeInfo info = afterHandshake(url, responseHeaders); + return new StandardWebSocketSession(session, info, this.bufferFactory); } - private static final class StandardWebSocketClientConfigurator extends Configurator { + private static final class DefaultConfigurator extends Configurator { private final HttpHeaders requestHeaders; - private HttpHeaders responseHeaders = new HttpHeaders(); + private final HttpHeaders responseHeaders = new HttpHeaders(); + - public StandardWebSocketClientConfigurator(HttpHeaders requestHeaders) { + public DefaultConfigurator(HttpHeaders requestHeaders) { this.requestHeaders = requestHeaders; } + + public HttpHeaders getResponseHeaders() { + return this.responseHeaders; + } + @Override public void beforeRequest(Map> requestHeaders) { requestHeaders.putAll(this.requestHeaders); @@ -149,11 +138,7 @@ public class StandardWebSocketClient extends WebSocketClientSupport implements W @Override public void afterResponse(HandshakeResponse response) { - response.getHeaders().forEach((k, v) -> responseHeaders.put(k, v)); - } - - public HttpHeaders getResponseHeaders() { - return this.responseHeaders; + response.getHeaders().forEach(this.responseHeaders::put); } } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/UndertowWebSocketClient.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/UndertowWebSocketClient.java index 0ee60771071..fe5403e13b2 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/UndertowWebSocketClient.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/UndertowWebSocketClient.java @@ -23,26 +23,20 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.concurrent.CancellationException; import java.util.function.Function; - import javax.net.ssl.SSLContext; import io.undertow.protocols.ssl.UndertowXnioSsl; import io.undertow.server.DefaultByteBufferPool; -import io.undertow.websockets.WebSocketExtension; +import io.undertow.websockets.client.WebSocketClient.ConnectionBuilder; import io.undertow.websockets.client.WebSocketClientNegotiation; import io.undertow.websockets.core.WebSocketChannel; - -import org.xnio.IoFuture; -import org.xnio.IoFuture.Notifier; import org.xnio.IoFuture.Status; import org.xnio.OptionMap; import org.xnio.Options; import org.xnio.Xnio; import org.xnio.XnioWorker; - import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; @@ -52,11 +46,13 @@ import org.springframework.http.HttpHeaders; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter; +import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession; /** * An Undertow based implementation of {@link WebSocketClient}. * * @author Violeta Georgieva + * @author Rossen Stoyanchev * @since 5.0 */ public class UndertowWebSocketClient extends WebSocketClientSupport implements WebSocketClient { @@ -82,7 +78,10 @@ public class UndertowWebSocketClient extends WebSocketClientSupport implements W } } - private final Function builder; + + private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + + private final Function builder; /** @@ -100,15 +99,13 @@ public class UndertowWebSocketClient extends WebSocketClientSupport implements W * instance. * @param builder a connection builder that can be used to create a web socket connection. */ - public UndertowWebSocketClient(Function builder) { + public UndertowWebSocketClient(Function builder) { this.builder = builder; } - private static io.undertow.websockets.client.WebSocketClient.ConnectionBuilder createDefaultConnectionBuilder( - URI url) { + private static ConnectionBuilder createDefaultConnectionBuilder(URI url) { - io.undertow.websockets.client.WebSocketClient.ConnectionBuilder builder = + ConnectionBuilder builder = io.undertow.websockets.client.WebSocketClient.connectionBuilder( worker, new DefaultByteBufferPool(false, DEFAULT_BUFFER_SIZE), url); @@ -135,99 +132,77 @@ public class UndertowWebSocketClient extends WebSocketClientSupport implements W @Override public Mono execute(URI url, HttpHeaders headers, WebSocketHandler handler) { - return connectInternal(url, headers, handler); + return executeInternal(url, headers, handler); } - private Mono connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { - MonoProcessor processor = MonoProcessor.create(); + private Mono executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { + MonoProcessor completionMono = MonoProcessor.create(); return Mono.fromCallable( () -> { - WSClientNegotiation clientNegotiation = - new WSClientNegotiation(beforeHandshake(url, headers, handler), - Collections.emptyList(), headers); - - io.undertow.websockets.client.WebSocketClient.ConnectionBuilder builder = - this.builder.apply(url).setClientNegotiation(clientNegotiation); - - IoFuture future = builder.connect(); - future.addNotifier(new ResultNotifier(url, handler, clientNegotiation, processor), new Object()); - return future; + String[] subProtocols = beforeHandshake(url, headers, handler); + DefaultNegotiation negotiation = new DefaultNegotiation(subProtocols, headers); + + return this.builder.apply(url) + .setClientNegotiation(negotiation) + .connect() + .addNotifier((future, attachment) -> { + if (Status.DONE.equals(future.getStatus())) { + WebSocketChannel channel; + try { + channel = future.get(); + } + catch (CancellationException | IOException ex) { + completionMono.onError(ex); + return; + } + handleWebSocket(url, handler, completionMono, negotiation, channel); + } + else if (Status.FAILED.equals(future.getStatus())) { + completionMono.onError(future.getException()); + } + else { + String message = "Failed to connect" + future.getStatus(); + completionMono.onError(new IllegalStateException(message)); + } + }, null); }) - .then(processor); + .then(completionMono); } + private void handleWebSocket(URI url, WebSocketHandler handler, MonoProcessor completionMono, + DefaultNegotiation negotiation, WebSocketChannel channel) { - private static final class ResultNotifier implements Notifier { - - private final URI url; - - private final WebSocketHandler handler; - - private final WSClientNegotiation clientNegotiation; - - private final MonoProcessor processor; - - public ResultNotifier(URI url, WebSocketHandler handler, - WSClientNegotiation clientNegotiation, MonoProcessor processor) { - this.url = url; - this.handler = handler; - this.clientNegotiation = clientNegotiation; - this.processor = processor; - } - - @Override - public void notify(IoFuture ioFuture, - Object attachment) { - if (Status.CANCELLED.equals(ioFuture.getStatus())) { - processor.onError(null); - } - else if (Status.FAILED.equals(ioFuture.getStatus())) { - processor.onError(ioFuture.getException()); - } - else if (Status.DONE.equals(ioFuture.getStatus())) { - try { - WebSocketChannel channel = ioFuture.get(); - DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); - HandshakeInfo info = new HandshakeInfo(url, clientNegotiation.getResponseHeaders(), - Mono.empty(), Optional.ofNullable(channel.getSubProtocol())); - - UndertowWebSocketHandlerAdapter adapter = - new UndertowWebSocketHandlerAdapter(handler, - info, bufferFactory, processor); - adapter.onConnect(null, channel); - } - catch (CancellationException | IOException ex) { - processor.onError(ex); - } - } - } + HandshakeInfo info = afterHandshake(url, negotiation.getResponseHeaders()); + UndertowWebSocketSession session = new UndertowWebSocketSession(channel, info, this.bufferFactory); + new UndertowWebSocketHandlerAdapter(handler, completionMono).handle(session); } - private static final class WSClientNegotiation extends WebSocketClientNegotiation { + private static final class DefaultNegotiation extends WebSocketClientNegotiation { private final HttpHeaders requestHeaders; private HttpHeaders responseHeaders = new HttpHeaders(); - public WSClientNegotiation(String[] subProtocols, - List extensions, HttpHeaders requestHeaders) { - super(Arrays.asList(subProtocols), extensions); + + public DefaultNegotiation(String[] subProtocols, HttpHeaders requestHeaders) { + super(Arrays.asList(subProtocols), Collections.emptyList()); this.requestHeaders = requestHeaders; } + + public HttpHeaders getResponseHeaders() { + return this.responseHeaders; + } + @Override public void beforeRequest(Map> headers) { - requestHeaders.forEach((k, v) -> headers.put(k, v)); + this.requestHeaders.forEach(headers::put); } @Override public void afterRequest(Map> headers) { - headers.forEach((k, v) -> responseHeaders.put(k, v)); - } - - public HttpHeaders getResponseHeaders() { - return responseHeaders; + headers.forEach((k, v) -> this.responseHeaders.put(k, v)); } } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClientSupport.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClientSupport.java index e1b7108d7a5..1173cc354e1 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClientSupport.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClientSupport.java @@ -41,20 +41,19 @@ public class WebSocketClientSupport { protected final Log logger = LogFactory.getLog(getClass()); - protected String[] beforeHandshake(URI url, HttpHeaders headers, WebSocketHandler handler) { + protected String[] beforeHandshake(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) { if (logger.isDebugEnabled()) { logger.debug("Executing handshake to " + url); } return handler.getSubProtocols(); } - protected HandshakeInfo afterHandshake(URI url, int statusCode, HttpHeaders headers) { - Assert.isTrue(statusCode == 101); + protected HandshakeInfo afterHandshake(URI url, HttpHeaders responseHeaders) { if (logger.isDebugEnabled()) { - logger.debug("Handshake response: " + url + ", " + headers); + logger.debug("Handshake response: " + url + ", " + responseHeaders); } - String protocol = headers.getFirst(SEC_WEBSOCKET_PROTOCOL); - return new HandshakeInfo(url, headers, Mono.empty(), Optional.ofNullable(protocol)); + String protocol = responseHeaders.getFirst(SEC_WEBSOCKET_PROTOCOL); + return new HandshakeInfo(url, responseHeaders, Mono.empty(), Optional.ofNullable(protocol)); } } \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java index 2c95112eb65..c265a9c287e 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java @@ -37,6 +37,7 @@ import org.springframework.util.Assert; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter; +import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.server.ServerWebExchange; @@ -118,9 +119,12 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life HttpServletRequest servletRequest = getHttpServletRequest(request); HttpServletResponse servletResponse = getHttpServletResponse(response); - HandshakeInfo info = getHandshakeInfo(exchange, subProtocol); - DataBufferFactory factory = response.bufferFactory(); - JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(handler, info, factory); + JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(handler, + null, session -> { + HandshakeInfo info = getHandshakeInfo(exchange, subProtocol); + DataBufferFactory factory = response.bufferFactory(); + return new JettyWebSocketSession(session, info, factory); + }); startLazily(servletRequest); diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java index 21467be9bb1..7bfcce57b09 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java @@ -38,6 +38,7 @@ import org.springframework.util.Assert; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter; +import org.springframework.web.reactive.socket.adapter.StandardWebSocketSession; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.server.ServerWebExchange; @@ -63,9 +64,12 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy { HttpServletRequest servletRequest = getHttpServletRequest(request); HttpServletResponse servletResponse = getHttpServletResponse(response); - HandshakeInfo info = getHandshakeInfo(exchange, subProtocol); - DataBufferFactory factory = response.bufferFactory(); - Endpoint endpoint = new StandardWebSocketHandlerAdapter(handler, info, factory).getEndpoint(); + Endpoint endpoint = new StandardWebSocketHandlerAdapter(handler, + session -> { + HandshakeInfo info = getHandshakeInfo(exchange, subProtocol); + DataBufferFactory factory = response.bufferFactory(); + return new StandardWebSocketSession(session, info, factory); + }); String requestURI = servletRequest.getRequestURI(); DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint); diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java index 959a7d75172..b3aec4ed3c0 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java @@ -16,6 +16,7 @@ package org.springframework.web.reactive.socket.server.upgrade; +import java.net.URI; import java.security.Principal; import java.util.Collections; import java.util.List; @@ -25,18 +26,21 @@ import java.util.Set; import io.undertow.server.HttpServerExchange; import io.undertow.websockets.WebSocketConnectionCallback; import io.undertow.websockets.WebSocketProtocolHandshakeHandler; +import io.undertow.websockets.core.WebSocketChannel; import io.undertow.websockets.core.protocol.Handshake; import io.undertow.websockets.core.protocol.version13.Hybi13Handshake; +import io.undertow.websockets.spi.WebSocketHttpExchange; import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; import org.springframework.http.server.reactive.ServerHttpRequest; -import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.UndertowServerHttpRequest; import org.springframework.util.Assert; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter; +import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.server.ServerWebExchange; @@ -55,22 +59,15 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { Optional subProtocol) { ServerHttpRequest request = exchange.getRequest(); - ServerHttpResponse response = exchange.getResponse(); - - HandshakeInfo info = getHandshakeInfo(exchange, subProtocol); - DataBufferFactory bufferFactory = response.bufferFactory(); - Assert.isTrue(request instanceof UndertowServerHttpRequest); HttpServerExchange httpExchange = ((UndertowServerHttpRequest) request).getUndertowExchange(); - WebSocketConnectionCallback callback = - new UndertowWebSocketHandlerAdapter(handler, info, bufferFactory); - Set protocols = subProtocol.map(Collections::singleton).orElse(Collections.emptySet()); Hybi13Handshake handshake = new Hybi13Handshake(protocols, false); List handshakes = Collections.singletonList(handshake); try { + DefaultCallback callback = new DefaultCallback(exchange, handler, subProtocol); new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange); } catch (Exception ex) { @@ -80,10 +77,44 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { return Mono.empty(); } - private HandshakeInfo getHandshakeInfo(ServerWebExchange exchange, Optional protocol) { - ServerHttpRequest request = exchange.getRequest(); - Mono principal = exchange.getPrincipal(); - return new HandshakeInfo(request.getURI(), request.getHeaders(), principal, protocol); + + private class DefaultCallback implements WebSocketConnectionCallback { + + private final ServerWebExchange exchange; + + private final WebSocketHandler handler; + + private final Optional subProtocol; + + + public DefaultCallback(ServerWebExchange exchange, WebSocketHandler handler, + Optional subProtocol) { + + this.exchange = exchange; + this.handler = handler; + this.subProtocol = subProtocol; + } + + @Override + public void onConnect(WebSocketHttpExchange httpExchange, WebSocketChannel channel) { + UndertowWebSocketHandlerAdapter adapter = new UndertowWebSocketHandlerAdapter(this.handler); + UndertowWebSocketSession session = createWebSocketSession(channel); + adapter.handle(session); + } + + private UndertowWebSocketSession createWebSocketSession(WebSocketChannel channel) { + HandshakeInfo info = getHandshakeInfo(); + DataBufferFactory bufferFactory = this.exchange.getResponse().bufferFactory(); + return new UndertowWebSocketSession(channel, info, bufferFactory); + } + + private HandshakeInfo getHandshakeInfo() { + ServerHttpRequest request = this.exchange.getRequest(); + URI url = request.getURI(); + HttpHeaders headers = request.getHeaders(); + Mono principal = this.exchange.getPrincipal(); + return new HandshakeInfo(url, headers, principal, this.subProtocol); + } } }