diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java index f95def729fe..6b8fbf80bbc 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java @@ -18,7 +18,6 @@ package org.springframework.web.reactive.socket.client; import java.net.URI; import java.security.NoSuchAlgorithmException; import java.util.function.Function; - import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; @@ -34,8 +33,9 @@ import rx.RxReactiveStreams; import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.http.HttpHeaders; -import org.springframework.web.reactive.socket.WebSocketSession; import org.springframework.web.reactive.socket.HandshakeInfo; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.WebSocketSession; import org.springframework.web.reactive.socket.adapter.RxNettyWebSocketSession; /** @@ -85,18 +85,18 @@ public class RxNettyWebSocketClient implements WebSocketClient { @Override - public Mono connect(URI url) { - return connect(url, new HttpHeaders()); + public Mono execute(URI url, WebSocketHandler handler) { + return execute(url, new HttpHeaders(), handler); } @Override - public Mono connect(URI url, HttpHeaders headers) { + public Mono execute(URI url, HttpHeaders headers, WebSocketHandler handler) { HandshakeInfo info = new HandshakeInfo(url, headers, Mono.empty()); - Observable observable = connectInternal(info); - return Mono.from(RxReactiveStreams.toPublisher(observable)); + Observable completion = connectInternal(handler, info); + return Mono.from(RxReactiveStreams.toPublisher(completion)); } - private Observable connectInternal(HandshakeInfo info) { + private Observable connectInternal(WebSocketHandler handler, HandshakeInfo info) { return createWebSocketRequest(info.getUri()) .flatMap(response -> { ByteBufAllocator allocator = response.unsafeNettyChannel().alloc(); @@ -104,10 +104,11 @@ public class RxNettyWebSocketClient implements WebSocketClient { Observable conn = response.getWebSocketConnection(); return Observable.zip(conn, Observable.just(bufferFactory), Tuples::of); }) - .map(tuple -> { + .flatMap(tuple -> { WebSocketConnection conn = tuple.getT1(); NettyDataBufferFactory bufferFactory = tuple.getT2(); - return new RxNettyWebSocketSession(conn, info, bufferFactory); + WebSocketSession session = new RxNettyWebSocketSession(conn, info, bufferFactory); + return RxReactiveStreams.toObservable(handler.handle(session)); }); } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClient.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClient.java index f8df12714ab..0ef241b66bb 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClient.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClient.java @@ -20,10 +20,10 @@ import java.net.URI; import reactor.core.publisher.Mono; import org.springframework.http.HttpHeaders; -import org.springframework.web.reactive.socket.WebSocketSession; +import org.springframework.web.reactive.socket.WebSocketHandler; /** - * Contract for starting a WebSocket interaction. + * Contract for connecting and handling a WebSocket session. * * @author Rossen Stoyanchev * @since 5.0 @@ -31,18 +31,23 @@ import org.springframework.web.reactive.socket.WebSocketSession; public interface WebSocketClient { /** - * Start a WebSocket interaction to the given url. + * Execute a handshake request to the given url and handle the resulting + * WebSocket session with the given handler. * @param url the handshake url - * @return the session for the WebSocket interaction + * @param handler the handler of the WebSocket session + * @return completion {@code Mono} to indicate the outcome of the + * WebSocket session handling */ - Mono connect(URI url); + Mono execute(URI url, WebSocketHandler handler); /** - * Start a WebSocket interaction to the given url. + * A variant of {@link #execute(URI, WebSocketHandler)} with custom headers. * @param url the handshake url - * @param headers headers for the handshake request - * @return the session for the WebSocket interaction + * @param headers custom headers for the handshake request + * @param handler the handler of the WebSocket session + * @return completion {@code Mono} to indicate the outcome of the + * WebSocket session handling */ - Mono connect(URI url, HttpHeaders headers); + Mono execute(URI url, HttpHeaders headers, WebSocketHandler handler); } diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketIntegrationTests.java index ff11e9a70b1..eb1e34df2a3 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketIntegrationTests.java @@ -16,8 +16,11 @@ package org.springframework.web.reactive.socket.server; import java.io.File; +import java.net.URI; +import java.net.URISyntaxException; import org.apache.tomcat.websocket.server.WsContextListener; +import org.jetbrains.annotations.NotNull; import org.junit.After; import org.junit.Before; import org.junit.runner.RunWith; @@ -112,6 +115,12 @@ public abstract class AbstractWebSocketIntegrationTests { } } + @NotNull + protected URI getUrl(String path) throws URISyntaxException { + return new URI("ws://localhost:" + this.port + path); + } + + static abstract class AbstractHandlerAdapterConfig { @Bean diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/WebSocketIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/WebSocketIntegrationTests.java index 9e90e47a4e8..71b0e1471d0 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/WebSocketIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/WebSocketIntegrationTests.java @@ -15,13 +15,13 @@ */ package org.springframework.web.reactive.socket.server; -import java.net.URI; import java.util.HashMap; import java.util.Map; import org.junit.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.ReplayProcessor; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -35,7 +35,6 @@ import static org.junit.Assert.assertEquals; /** * Integration tests with server-side {@link WebSocketHandler}s. - * * @author Rossen Stoyanchev */ @SuppressWarnings({"unused", "WeakerAccess"}) @@ -52,9 +51,10 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests public void echo() throws Exception { int count = 100; Flux input = Flux.range(1, count).map(index -> "msg-" + index); - Flux output = new RxNettyWebSocketClient() - .connect(new URI("ws://localhost:" + this.port + "/echo")) - .flatMap(session -> session + ReplayProcessor emitter = ReplayProcessor.create(count); + + new RxNettyWebSocketClient() + .execute(getUrl("/echo"), session -> session .send(input.map(session::textMessage)) .thenMany(session.receive() .take(count) @@ -62,9 +62,12 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests String text = message.getPayloadAsText(); message.release(); return text; - }) - )); - assertEquals(input.collectList().blockMillis(5000), output.collectList().blockMillis(5000)); + })) + .subscribeWith(emitter) + .then()) + .blockMillis(5000); + + assertEquals(input.collectList().blockMillis(5000), emitter.collectList().blockMillis(5000)); }