diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardEndpoint.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardEndpoint.java new file mode 100644 index 00000000000..94fc35597de --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardEndpoint.java @@ -0,0 +1,168 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.socket.adapter; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import javax.websocket.CloseReason; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.PongMessage; +import javax.websocket.Session; + +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import reactor.core.publisher.MonoProcessor; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.web.reactive.socket.CloseStatus; +import org.springframework.web.reactive.socket.HandshakeInfo; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.WebSocketMessage; +import org.springframework.web.reactive.socket.WebSocketMessage.Type; + +/** + * {@link Endpoint} delegating events + * to a reactive {@link WebSocketHandler} and its session. + * + * @author Violeta Georgieva + * @since 5.0 + */ +public class StandardEndpoint extends Endpoint { + + private final WebSocketHandler handler; + + private final HandshakeInfo info; + + private final DataBufferFactory bufferFactory; + + private StandardWebSocketSession delegateSession; + + private final MonoProcessor completionMono; + + + public StandardEndpoint(WebSocketHandler handler, HandshakeInfo info, + DataBufferFactory bufferFactory) { + this(handler, info, bufferFactory, null); + } + + public StandardEndpoint(WebSocketHandler handler, HandshakeInfo info, + DataBufferFactory bufferFactory, MonoProcessor completionMono) { + this.handler = handler; + this.info = info; + this.bufferFactory = bufferFactory; + this.completionMono = completionMono; + } + + + @Override + public void onOpen(Session nativeSession, EndpointConfig config) { + + this.delegateSession = new StandardWebSocketSession(nativeSession, this.info, this.bufferFactory); + + nativeSession.addMessageHandler(String.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); + }); + nativeSession.addMessageHandler(ByteBuffer.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); + }); + nativeSession.addMessageHandler(PongMessage.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); + }); + + HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); + this.handler.handle(this.delegateSession).subscribe(resultSubscriber); + } + + private WebSocketMessage toMessage(T message) { + if (message instanceof String) { + byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); + return new WebSocketMessage(Type.TEXT, this.bufferFactory.wrap(bytes)); + } + else if (message instanceof ByteBuffer) { + DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message); + return new WebSocketMessage(Type.BINARY, buffer); + } + else if (message instanceof PongMessage) { + DataBuffer buffer = this.bufferFactory.wrap(((PongMessage) message).getApplicationData()); + return new WebSocketMessage(Type.PONG, buffer); + } + else { + throw new IllegalArgumentException("Unexpected message type: " + message); + } + } + + @Override + public void onClose(Session session, CloseReason reason) { + if (this.delegateSession != null) { + int code = reason.getCloseCode().getCode(); + this.delegateSession.handleClose(new CloseStatus(code, reason.getReasonPhrase())); + } + } + + @Override + public void onError(Session session, Throwable exception) { + if (this.delegateSession != null) { + this.delegateSession.handleError(exception); + } + } + + protected HandshakeInfo getHandshakeInfo() { + return this.info; + } + + + private final class HandlerResultSubscriber implements Subscriber { + + @Override + public void onSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void aVoid) { + // no op + } + + @Override + public void onError(Throwable ex) { + if (completionMono != null) { + completionMono.onError(ex); + } + if (delegateSession != null) { + int code = CloseStatus.SERVER_ERROR.getCode(); + delegateSession.close(new CloseStatus(code, ex.getMessage())); + } + } + + @Override + public void onComplete() { + if (completionMono != null) { + completionMono.onComplete(); + } + if (delegateSession != null) { + delegateSession.close(); + } + } + } + +} \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardWebSocketHandlerAdapter.java index dd79544e447..32d8ed30438 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardWebSocketHandlerAdapter.java @@ -16,28 +16,14 @@ package org.springframework.web.reactive.socket.adapter; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import javax.websocket.CloseReason; import javax.websocket.Endpoint; -import javax.websocket.EndpointConfig; -import javax.websocket.PongMessage; -import javax.websocket.Session; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; -import org.springframework.web.reactive.socket.CloseStatus; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; -import org.springframework.web.reactive.socket.WebSocketMessage; -import org.springframework.web.reactive.socket.WebSocketMessage.Type; /** - * Adapter for Java WebSocket API (JSR-356) {@link Endpoint} delegating events - * to a reactive {@link WebSocketHandler} and its session. + * Adapter for Java WebSocket API (JSR-356). * * @author Violeta Georgieva * @author Rossen Stoyanchev @@ -45,8 +31,6 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type; */ public class StandardWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport { - private StandardWebSocketSession delegateSession; - public StandardWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info, DataBufferFactory bufferFactory) { @@ -56,94 +40,7 @@ public class StandardWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupp public Endpoint getEndpoint() { - return new StandardEndpoint(); - } - - - private class StandardEndpoint extends Endpoint { - - @Override - public void onOpen(Session nativeSession, EndpointConfig config) { - - delegateSession = new StandardWebSocketSession(nativeSession, getHandshakeInfo(), bufferFactory()); - - nativeSession.addMessageHandler(String.class, message -> { - WebSocketMessage webSocketMessage = toMessage(message); - delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); - }); - nativeSession.addMessageHandler(ByteBuffer.class, message -> { - WebSocketMessage webSocketMessage = toMessage(message); - delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); - }); - nativeSession.addMessageHandler(PongMessage.class, message -> { - WebSocketMessage webSocketMessage = toMessage(message); - delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); - }); - - HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); - getDelegate().handle(delegateSession).subscribe(resultSubscriber); - } - - private WebSocketMessage toMessage(T message) { - if (message instanceof String) { - byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); - return new WebSocketMessage(Type.TEXT, bufferFactory().wrap(bytes)); - } - else if (message instanceof ByteBuffer) { - DataBuffer buffer = bufferFactory().wrap((ByteBuffer) message); - return new WebSocketMessage(Type.BINARY, buffer); - } - else if (message instanceof PongMessage) { - DataBuffer buffer = bufferFactory().wrap(((PongMessage) message).getApplicationData()); - return new WebSocketMessage(Type.PONG, buffer); - } - else { - throw new IllegalArgumentException("Unexpected message type: " + message); - } - } - - @Override - public void onClose(Session session, CloseReason reason) { - if (delegateSession != null) { - int code = reason.getCloseCode().getCode(); - delegateSession.handleClose(new CloseStatus(code, reason.getReasonPhrase())); - } - } - - @Override - public void onError(Session session, Throwable exception) { - if (delegateSession != null) { - delegateSession.handleError(exception); - } - } - } - - private final class HandlerResultSubscriber implements Subscriber { - - @Override - public void onSubscribe(Subscription subscription) { - subscription.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Void aVoid) { - // no op - } - - @Override - public void onError(Throwable ex) { - if (delegateSession != null) { - int code = CloseStatus.SERVER_ERROR.getCode(); - delegateSession.close(new CloseStatus(code, ex.getMessage())); - } - } - - @Override - public void onComplete() { - if (delegateSession != null) { - delegateSession.close(); - } - } + return new StandardEndpoint(getDelegate(), getHandshakeInfo(), bufferFactory()); } } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java new file mode 100644 index 00000000000..1a4c47a16c7 --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java @@ -0,0 +1,160 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.socket.client; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import javax.websocket.ClientEndpointConfig; +import javax.websocket.ContainerProvider; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.HandshakeResponse; +import javax.websocket.Session; +import javax.websocket.WebSocketContainer; +import javax.websocket.ClientEndpointConfig.Configurator; + +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; + +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.web.reactive.socket.HandshakeInfo; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.adapter.StandardEndpoint; + +/** + * A Java WebSocket API (JSR-356) based implementation of + * {@link WebSocketClient}. + * + * @author Violeta Georgieva + * @since 5.0 + */ +public class StandardWebSocketClient extends WebSocketClientSupport implements WebSocketClient { + + private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + + private final WebSocketContainer wsContainer; + + + /** + * Default constructor that calls {@code ContainerProvider.getWebSocketContainer()} + * to obtain a (new) {@link WebSocketContainer} instance. + */ + public StandardWebSocketClient() { + this(ContainerProvider.getWebSocketContainer()); + } + + /** + * Constructor accepting an existing {@link WebSocketContainer} instance. + * @param wsContainer a web socket container + */ + public StandardWebSocketClient(WebSocketContainer wsContainer) { + this.wsContainer = wsContainer; + } + + + @Override + public Mono execute(URI url, WebSocketHandler handler) { + return execute(url, new HttpHeaders(), handler); + } + + @Override + public Mono execute(URI url, HttpHeaders headers, WebSocketHandler handler) { + return connectInternal(url, headers, handler); + } + + private Mono connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { + MonoProcessor processor = MonoProcessor.create(); + return Mono.fromCallable(() -> { + StandardWebSocketClientConfigurator configurator = + new StandardWebSocketClientConfigurator(headers); + + ClientEndpointConfig endpointConfig = createClientEndpointConfig( + configurator, beforeHandshake(url, headers, handler)); + + HandshakeInfo info = new HandshakeInfo(url, Mono.empty()); + + Endpoint endpoint = new StandardClientEndpoint(handler, info, + this.bufferFactory, configurator, processor); + + Session session = this.wsContainer.connectToServer(endpoint, endpointConfig, url); + return session; + }).then(processor); + } + + private ClientEndpointConfig createClientEndpointConfig( + StandardWebSocketClientConfigurator configurator, String[] subProtocols) { + + return ClientEndpointConfig.Builder.create() + .configurator(configurator) + .preferredSubprotocols(Arrays.asList(subProtocols)) + .build(); + } + + + private static final class StandardClientEndpoint extends StandardEndpoint { + + private final StandardWebSocketClientConfigurator configurator; + + public StandardClientEndpoint(WebSocketHandler handler, HandshakeInfo info, + DataBufferFactory bufferFactory, StandardWebSocketClientConfigurator configurator, + MonoProcessor processor) { + super(handler, info, bufferFactory, processor); + this.configurator = configurator; + } + + @Override + public void onOpen(Session nativeSession, EndpointConfig config) { + getHandshakeInfo().setHeaders(this.configurator.getResponseHeaders()); + getHandshakeInfo().setSubProtocol( + Optional.ofNullable(nativeSession.getNegotiatedSubprotocol())); + + super.onOpen(nativeSession, config); + } + } + + + private static final class StandardWebSocketClientConfigurator extends Configurator { + + private final HttpHeaders requestHeaders; + + private HttpHeaders responseHeaders = new HttpHeaders(); + + public StandardWebSocketClientConfigurator(HttpHeaders requestHeaders) { + this.requestHeaders = requestHeaders; + } + + @Override + public void beforeRequest(Map> requestHeaders) { + requestHeaders.putAll(this.requestHeaders); + } + + @Override + public void afterResponse(HandshakeResponse response) { + response.getHeaders().forEach((k, v) -> responseHeaders.put(k, v)); + } + + public HttpHeaders getResponseHeaders() { + return this.responseHeaders; + } + } + +} \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java index 58413692ca8..21467be9bb1 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java @@ -25,7 +25,6 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.Endpoint; -import javax.websocket.server.ServerEndpointConfig; import org.apache.tomcat.websocket.server.WsServerContainer; import reactor.core.publisher.Mono; diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/WebSocketIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/WebSocketIntegrationTests.java index f667de15560..41017303713 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/WebSocketIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/WebSocketIntegrationTests.java @@ -41,6 +41,7 @@ import org.springframework.web.reactive.socket.WebSocketSession; import org.springframework.web.reactive.socket.client.JettyWebSocketClient; import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient; import org.springframework.web.reactive.socket.client.RxNettyWebSocketClient; +import org.springframework.web.reactive.socket.client.StandardWebSocketClient; import org.springframework.web.reactive.socket.client.WebSocketClient; import static org.junit.Assert.assertEquals; @@ -78,6 +79,11 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests client.stop(); } + @Test + public void echoStandardClient() throws Exception { + testEcho(new StandardWebSocketClient()); + } + private void testEcho(WebSocketClient client) throws URISyntaxException { int count = 100; Flux input = Flux.range(1, count).map(index -> "msg-" + index); @@ -113,6 +119,11 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests client.stop(); } + @Test + public void subProtocolStandardClient() throws Exception { + testSubProtocol(new StandardWebSocketClient()); + } + private void testSubProtocol(WebSocketClient client) throws URISyntaxException { String protocol = "echo-v1"; AtomicReference infoRef = new AtomicReference<>();