diff --git a/build.gradle b/build.gradle index ef2425a84b2..7c7f576b7fb 100644 --- a/build.gradle +++ b/build.gradle @@ -829,6 +829,9 @@ project("spring-web-reactive") { exclude group: "org.apache.tomcat", module: "tomcat-websocket-api" exclude group: "org.apache.tomcat", module: "tomcat-servlet-api" } + optional("io.undertow:undertow-websockets-jsr:${undertowVersion}") { + exclude group: "org.jboss.spec.javax.websocket", module: "jboss-websocket-api_1.1_spec" + } testCompile("io.projectreactor.addons:reactor-test:${reactorCoreVersion}") testCompile("javax.validation:validation-api:${beanvalVersion}") testCompile("org.hibernate:hibernate-validator:${hibval5Version}") diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSessionSupport.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSessionSupport.java new file mode 100644 index 00000000000..44fbe4bb849 --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSessionSupport.java @@ -0,0 +1,195 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.socket.adapter; + +import java.io.IOException; +import java.net.URI; + +import java.util.concurrent.atomic.AtomicBoolean; + +import org.reactivestreams.Publisher; +import org.springframework.http.server.reactive.AbstractRequestBodyPublisher; +import org.springframework.http.server.reactive.AbstractResponseBodyProcessor; +import org.springframework.util.Assert; +import org.springframework.web.reactive.socket.CloseStatus; +import org.springframework.web.reactive.socket.WebSocketMessage; +import org.springframework.web.reactive.socket.WebSocketSession; +import org.springframework.web.reactive.socket.WebSocketMessage.Type; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Base class for Listener-based {@link WebSocketSession} adapters. + * + * @author Violeta Georgieva + * @since 5.0 + */ +public abstract class AbstractListenerWebSocketSessionSupport extends WebSocketSessionSupport { + + private final AtomicBoolean sendCalled = new AtomicBoolean(); + + private final String id; + + private final URI uri; + + protected final WebSocketMessagePublisher webSocketMessagePublisher = + new WebSocketMessagePublisher(); + + protected volatile WebSocketMessageProcessor webSocketMessageProcessor; + + public AbstractListenerWebSocketSessionSupport(T delegate, String id, URI uri) { + super(delegate); + Assert.notNull(id, "'id' is required."); + Assert.notNull(uri, "'uri' is required."); + this.id = id; + this.uri = uri; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public URI getUri() { + return this.uri; + } + + @Override + public Flux receive() { + return Flux.from(this.webSocketMessagePublisher); + } + + @Override + public Mono send(Publisher messages) { + if (this.sendCalled.compareAndSet(false, true)) { + this.webSocketMessageProcessor = new WebSocketMessageProcessor(); + return Mono.from(subscriber -> { + messages.subscribe(this.webSocketMessageProcessor); + this.webSocketMessageProcessor.subscribe(subscriber); + }); + } + else { + return Mono.error(new IllegalStateException("send() has already been called")); + } + } + + protected void resumeReceives() { + // no-op + } + + protected void suspendReceives() { + // no-op + } + + protected abstract boolean writeInternal(WebSocketMessage message) throws IOException; + + /** Handle a message callback from the Servlet container */ + void handleMessage(Type type, WebSocketMessage message) { + this.webSocketMessagePublisher.processWebSocketMessage(message); + } + + /** Handle a error callback from the Servlet container */ + void handleError(Throwable ex) { + this.webSocketMessagePublisher.onError(ex); + if (this.webSocketMessageProcessor != null) { + this.webSocketMessageProcessor.cancel(); + this.webSocketMessageProcessor.onError(ex); + } + } + + /** Handle a complete callback from the Servlet container */ + void handleClose(CloseStatus reason) { + this.webSocketMessagePublisher.onAllDataRead(); + if (this.webSocketMessageProcessor != null) { + this.webSocketMessageProcessor.cancel(); + this.webSocketMessageProcessor.onComplete(); + } + } + + final class WebSocketMessagePublisher extends AbstractRequestBodyPublisher { + private volatile WebSocketMessage webSocketMessage; + + @Override + protected void checkOnDataAvailable() { + if (this.webSocketMessage != null) { + onDataAvailable(); + } + } + + @Override + protected WebSocketMessage read() throws IOException { + if (this.webSocketMessage != null) { + WebSocketMessage result = this.webSocketMessage; + this.webSocketMessage = null; + resumeReceives(); + return result; + } + + return null; + } + + void processWebSocketMessage(WebSocketMessage webSocketMessage) { + this.webSocketMessage = webSocketMessage; + suspendReceives(); + onDataAvailable(); + } + + boolean canAccept() { + return this.webSocketMessage == null; + } + } + + final class WebSocketMessageProcessor extends AbstractResponseBodyProcessor { + private volatile boolean isReady = true; + + @Override + protected boolean write(WebSocketMessage message) throws IOException { + return writeInternal(message); + } + + @Override + protected void releaseData() { + if (logger.isTraceEnabled()) { + logger.trace("releaseBuffer: " + this.currentData); + } + this.currentData = null; + } + + @Override + protected boolean isDataEmpty(WebSocketMessage data) { + return data.getPayload().readableByteCount() == 0; + } + + @Override + protected boolean isWritePossible() { + if (this.isReady && this.currentData != null) { + return true; + } + else { + return false; + } + } + + void setReady(boolean ready) { + this.isReady = ready; + } + + } + +} diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketHandlerAdapter.java index e3066a5323e..12fa2980e8b 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketHandlerAdapter.java @@ -110,7 +110,8 @@ public class TomcatWebSocketHandlerAdapter extends Endpoint { @Override public void onClose(Session session, CloseReason reason) { if (this.wsSession != null) { - this.wsSession.handleClose(reason); + this.wsSession.handleClose( + new CloseStatus(reason.getCloseCode().getCode(), reason.getReasonPhrase())); } } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketSession.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketSession.java index bc30d7c26bc..3e9ab0258ae 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketSession.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketSession.java @@ -17,9 +17,7 @@ package org.springframework.web.reactive.socket.adapter; import java.io.IOException; -import java.net.URI; import java.nio.charset.StandardCharsets; -import java.util.concurrent.atomic.AtomicBoolean; import javax.websocket.CloseReason; import javax.websocket.SendHandler; @@ -27,15 +25,10 @@ import javax.websocket.SendResult; import javax.websocket.Session; import javax.websocket.CloseReason.CloseCodes; -import org.reactivestreams.Publisher; -import org.springframework.http.server.reactive.AbstractRequestBodyPublisher; -import org.springframework.http.server.reactive.AbstractResponseBodyProcessor; import org.springframework.web.reactive.socket.CloseStatus; import org.springframework.web.reactive.socket.WebSocketMessage; import org.springframework.web.reactive.socket.WebSocketSession; -import org.springframework.web.reactive.socket.WebSocketMessage.Type; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** @@ -45,58 +38,17 @@ import reactor.core.publisher.Mono; * @author Violeta Georgieva * @since 5.0 */ -public class TomcatWebSocketSession extends WebSocketSessionSupport { - - private final AtomicBoolean sendCalled = new AtomicBoolean(); - - private final WebSocketMessagePublisher webSocketMessagePublisher = - new WebSocketMessagePublisher(); - - private final String id; - - private final URI uri; - - private volatile WebSocketMessageProcessor webSocketMessageProcessor; +public class TomcatWebSocketSession extends AbstractListenerWebSocketSessionSupport { public TomcatWebSocketSession(Session session) { - super(session); - this.id = session.getId(); - this.uri = session.getRequestURI(); - } - - @Override - public String getId() { - return this.id; - } - - @Override - public URI getUri() { - return this.uri; - } - - @Override - public Flux receive() { - return Flux.from(this.webSocketMessagePublisher); - } - - @Override - public Mono send(Publisher messages) { - if (this.sendCalled.compareAndSet(false, true)) { - this.webSocketMessageProcessor = new WebSocketMessageProcessor(); - return Mono.from(subscriber -> { - messages.subscribe(this.webSocketMessageProcessor); - this.webSocketMessageProcessor.subscribe(subscriber); - }); - } - else { - return Mono.error(new IllegalStateException("send() has already been called")); - } + super(session, session.getId(), session.getRequestURI()); } @Override protected Mono closeInternal(CloseStatus status) { try { - getDelegate().close(new CloseReason(CloseCodes.getCloseCode(status.getCode()), status.getReason())); + getDelegate().close( + new CloseReason(CloseCodes.getCloseCode(status.getCode()), status.getReason())); } catch (IOException e) { return Mono.error(e); @@ -108,125 +60,43 @@ public class TomcatWebSocketSession extends WebSocketSessionSupport { return this.webSocketMessagePublisher.canAccept(); } - /** Handle a message callback from the Servlet container */ - void handleMessage(Type type, WebSocketMessage message) { - this.webSocketMessagePublisher.processWebSocketMessage(message); - } - - /** Handle a error callback from the Servlet container */ - void handleError(Throwable ex) { - this.webSocketMessagePublisher.onError(ex); - if (this.webSocketMessageProcessor != null) { - this.webSocketMessageProcessor.cancel(); - this.webSocketMessageProcessor.onError(ex); + @Override + protected boolean writeInternal(WebSocketMessage message) throws IOException { + if (WebSocketMessage.Type.TEXT.equals(message.getType())) { + this.webSocketMessageProcessor.setReady(false); + getDelegate().getAsyncRemote().sendText( + new String(message.getPayload().asByteBuffer().array(), StandardCharsets.UTF_8), + new WebSocketMessageSendHandler()); } - } - - /** Handle a complete callback from the Servlet container */ - void handleClose(CloseReason reason) { - this.webSocketMessagePublisher.onAllDataRead(); - if (this.webSocketMessageProcessor != null) { - this.webSocketMessageProcessor.cancel(); - this.webSocketMessageProcessor.onComplete(); + else if (WebSocketMessage.Type.BINARY.equals(message.getType())) { + this.webSocketMessageProcessor.setReady(false); + getDelegate().getAsyncRemote().sendBinary(message.getPayload().asByteBuffer(), + new WebSocketMessageSendHandler()); } - } - - private static final class WebSocketMessagePublisher extends AbstractRequestBodyPublisher { - private volatile WebSocketMessage webSocketMessage; - - @Override - protected void checkOnDataAvailable() { - if (this.webSocketMessage != null) { - onDataAvailable(); - } + else if (WebSocketMessage.Type.PING.equals(message.getType())) { + getDelegate().getAsyncRemote().sendPing(message.getPayload().asByteBuffer()); } - - @Override - protected WebSocketMessage read() throws IOException { - if (this.webSocketMessage != null) { - WebSocketMessage result = this.webSocketMessage; - this.webSocketMessage = null; - return result; - } - - return null; + else if (WebSocketMessage.Type.PONG.equals(message.getType())) { + getDelegate().getAsyncRemote().sendPong(message.getPayload().asByteBuffer()); } - - void processWebSocketMessage(WebSocketMessage webSocketMessage) { - this.webSocketMessage = webSocketMessage; - onDataAvailable(); - } - - boolean canAccept() { - return this.webSocketMessage == null; + else { + throw new IllegalArgumentException("Unexpected message type: " + message.getType()); } + return true; } - private final class WebSocketMessageProcessor extends AbstractResponseBodyProcessor { - private volatile boolean isReady = true; + private final class WebSocketMessageSendHandler implements SendHandler { @Override - protected boolean write(WebSocketMessage message) throws IOException { - if (WebSocketMessage.Type.TEXT.equals(message.getType())) { - this.isReady = false; - getDelegate().getAsyncRemote().sendText( - new String(message.getPayload().asByteBuffer().array(), StandardCharsets.UTF_8), - new WebSocketMessageSendHandler()); - } - else if (WebSocketMessage.Type.BINARY.equals(message.getType())) { - this.isReady = false; - getDelegate().getAsyncRemote().sendBinary(message.getPayload().asByteBuffer(), - new WebSocketMessageSendHandler()); - } - else if (WebSocketMessage.Type.PING.equals(message.getType())) { - getDelegate().getAsyncRemote().sendPing(message.getPayload().asByteBuffer()); - } - else if (WebSocketMessage.Type.PONG.equals(message.getType())) { - getDelegate().getAsyncRemote().sendPong(message.getPayload().asByteBuffer()); + public void onResult(SendResult result) { + if (result.isOK()) { + webSocketMessageProcessor.setReady(true); + webSocketMessageProcessor.onWritePossible(); } else { - throw new IllegalArgumentException("Unexpected message type: " + message.getType()); - } - return true; - } - - @Override - protected void releaseData() { - if (logger.isTraceEnabled()) { - logger.trace("releaseBuffer: " + this.currentData); + webSocketMessageProcessor.cancel(); + webSocketMessageProcessor.onError(result.getException()); } - this.currentData = null; - } - - @Override - protected boolean isDataEmpty(WebSocketMessage data) { - return data.getPayload().readableByteCount() == 0; - } - - @Override - protected boolean isWritePossible() { - if (this.isReady && this.currentData != null) { - return true; - } - else { - return false; - } - } - - private final class WebSocketMessageSendHandler implements SendHandler { - - @Override - public void onResult(SendResult result) { - if (result.isOK()) { - isReady = true; - webSocketMessageProcessor.onWritePossible(); - } - else { - webSocketMessageProcessor.cancel(); - webSocketMessageProcessor.onError(result.getException()); - } - } - } } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java new file mode 100644 index 00000000000..ec4ea0c1558 --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java @@ -0,0 +1,156 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.socket.adapter; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.util.Assert; +import org.springframework.web.reactive.socket.CloseStatus; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.WebSocketMessage; +import org.springframework.web.reactive.socket.WebSocketMessage.Type; + +import io.undertow.websockets.WebSocketConnectionCallback; +import io.undertow.websockets.core.AbstractReceiveListener; +import io.undertow.websockets.core.BufferedBinaryMessage; +import io.undertow.websockets.core.BufferedTextMessage; +import io.undertow.websockets.core.CloseMessage; +import io.undertow.websockets.core.WebSocketChannel; +import io.undertow.websockets.spi.WebSocketHttpExchange; + +/** + * Undertow {@code WebSocketHandler} implementation adapting and + * delegating to a Spring {@link WebSocketHandler}. + * + * @author Violeta Georgieva + * @since 5.0 + */ +public class UndertowWebSocketHandlerAdapter implements WebSocketConnectionCallback { + + private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false); + + private final WebSocketHandler handler; + + private UndertowWebSocketSession wsSession; + + public UndertowWebSocketHandlerAdapter(WebSocketHandler handler) { + Assert.notNull("'handler' is required"); + this.handler = handler; + } + + @Override + public void onConnect(WebSocketHttpExchange exchange, WebSocketChannel channel) { + try { + this.wsSession = new UndertowWebSocketSession(channel); + } + catch (URISyntaxException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + + channel.getReceiveSetter().set(new ReceiveListener()); + channel.resumeReceives(); + + HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); + this.handler.handle(this.wsSession).subscribe(resultSubscriber); + } + + private final class ReceiveListener extends AbstractReceiveListener { + + @Override + protected void onFullTextMessage(WebSocketChannel channel, BufferedTextMessage message) { + wsSession.handleMessage(Type.TEXT, toMessage(Type.TEXT, message.getData())); + } + + @Override + protected void onFullBinaryMessage(WebSocketChannel channel, + BufferedBinaryMessage message) throws IOException { + wsSession.handleMessage(Type.BINARY, toMessage(Type.BINARY, message.getData().getResource())); + message.getData().free(); + } + + @Override + protected void onFullPongMessage(WebSocketChannel channel, + BufferedBinaryMessage message) throws IOException { + wsSession.handleMessage(Type.PONG, toMessage(Type.PONG, message.getData().getResource())); + message.getData().free(); + } + + @Override + protected void onFullCloseMessage(WebSocketChannel channel, + BufferedBinaryMessage message) throws IOException { + CloseMessage closeMessage = new CloseMessage(message.getData().getResource()); + wsSession.handleClose(new CloseStatus(closeMessage.getCode(), closeMessage.getReason())); + message.getData().free(); + } + + @Override + protected void onError(WebSocketChannel channel, Throwable error) { + wsSession.handleError(error); + } + + private WebSocketMessage toMessage(Type type, T message) { + if (Type.TEXT.equals(type)) { + return WebSocketMessage.create(Type.TEXT, + bufferFactory.wrap(((String) message).getBytes(StandardCharsets.UTF_8))); + } + else if (Type.BINARY.equals(type)) { + return WebSocketMessage.create(Type.BINARY, + bufferFactory.allocateBuffer().write((ByteBuffer[]) message)); + } + else if (Type.PONG.equals(type)) { + return WebSocketMessage.create(Type.PONG, + bufferFactory.allocateBuffer().write((ByteBuffer[]) message)); + } + else { + throw new IllegalArgumentException("Unexpected message type: " + message); + } + } + + } + + private final class HandlerResultSubscriber implements Subscriber { + + @Override + public void onSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void aVoid) { + // no op + } + + @Override + public void onError(Throwable ex) { + wsSession.close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage())); + } + + @Override + public void onComplete() { + wsSession.close(); + } + } + +} diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java new file mode 100644 index 00000000000..bacad80f64a --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.socket.adapter; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import org.springframework.util.ObjectUtils; +import org.springframework.web.reactive.socket.CloseStatus; +import org.springframework.web.reactive.socket.WebSocketMessage; +import org.springframework.web.reactive.socket.WebSocketSession; + +import io.undertow.websockets.core.CloseMessage; +import io.undertow.websockets.core.WebSocketCallback; +import io.undertow.websockets.core.WebSocketChannel; +import io.undertow.websockets.core.WebSockets; +import reactor.core.publisher.Mono; + +/** + * Spring {@link WebSocketSession} adapter for Undertow's + * {@link io.undertow.websockets.core.WebSocketChannel}. + * + * @author Violeta Georgieva + * @since 5.0 + */ +public class UndertowWebSocketSession extends AbstractListenerWebSocketSessionSupport { + + public UndertowWebSocketSession(WebSocketChannel channel) throws URISyntaxException { + super(channel, ObjectUtils.getIdentityHexString(channel), new URI(channel.getUrl())); + } + + @Override + protected Mono closeInternal(CloseStatus status) { + CloseMessage cm = new CloseMessage(status.getCode(), status.getReason()); + if (!getDelegate().isCloseFrameSent()) { + WebSockets.sendClose(cm, getDelegate(), null); + } + return Mono.empty(); + } + + protected void resumeReceives() { + getDelegate().resumeReceives(); + } + + protected void suspendReceives() { + getDelegate().suspendReceives(); + } + + @Override + protected boolean writeInternal(WebSocketMessage message) throws IOException { + if (WebSocketMessage.Type.TEXT.equals(message.getType())) { + this.webSocketMessageProcessor.setReady(false); + WebSockets.sendText( + new String(message.getPayload().asByteBuffer().array(), StandardCharsets.UTF_8), + getDelegate(), new WebSocketMessageSendHandler()); + } + else if (WebSocketMessage.Type.BINARY.equals(message.getType())) { + this.webSocketMessageProcessor.setReady(false); + WebSockets.sendBinary(message.getPayload().asByteBuffer(), + getDelegate(), new WebSocketMessageSendHandler()); + } + else if (WebSocketMessage.Type.PING.equals(message.getType())) { + this.webSocketMessageProcessor.setReady(false); + WebSockets.sendPing(message.getPayload().asByteBuffer(), + getDelegate(), new WebSocketMessageSendHandler()); + } + else if (WebSocketMessage.Type.PONG.equals(message.getType())) { + this.webSocketMessageProcessor.setReady(false); + WebSockets.sendPong(message.getPayload().asByteBuffer(), + getDelegate(), new WebSocketMessageSendHandler()); + } + else { + throw new IllegalArgumentException("Unexpected message type: " + message.getType()); + } + return true; + } + + private final class WebSocketMessageSendHandler implements WebSocketCallback { + + @Override + public void complete(WebSocketChannel channel, Void context) { + webSocketMessageProcessor.setReady(true); + webSocketMessageProcessor.onWritePossible(); + } + + @Override + public void onError(WebSocketChannel channel, Void context, + Throwable throwable) { + webSocketMessageProcessor.cancel(); + webSocketMessageProcessor.onError(throwable); + } + + } + +} diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java new file mode 100644 index 00000000000..7cc22ddda7e --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.socket.server.upgrade; + +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.UndertowServerHttpRequest; +import org.springframework.util.Assert; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter; +import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; +import org.springframework.web.server.ServerWebExchange; + +import io.undertow.server.HttpServerExchange; +import io.undertow.websockets.WebSocketProtocolHandshakeHandler; +import reactor.core.publisher.Mono; + +/** +* A {@link RequestUpgradeStrategy} for use with Undertow. + * + * @author Violeta Georgieva + * @since 5.0 + */ +public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { + + @Override + public Mono upgrade(ServerWebExchange exchange, + WebSocketHandler webSocketHandler) { + + UndertowWebSocketHandlerAdapter callback = + new UndertowWebSocketHandlerAdapter(webSocketHandler); + + WebSocketProtocolHandshakeHandler handler = + new WebSocketProtocolHandshakeHandler(callback); + try { + handler.handleRequest(getUndertowExchange(exchange.getRequest())); + } + catch (Exception e) { + return Mono.error(e); + } + + return Mono.empty(); + } + + private final HttpServerExchange getUndertowExchange(ServerHttpRequest request) { + Assert.isTrue(request instanceof UndertowServerHttpRequest); + return ((UndertowServerHttpRequest) request).getUndertowExchange(); + } + +} diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketHandlerIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketHandlerIntegrationTests.java index de1d4d9b2a2..5d494253cb9 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketHandlerIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketHandlerIntegrationTests.java @@ -33,6 +33,7 @@ import org.springframework.http.server.reactive.bootstrap.HttpServer; import org.springframework.http.server.reactive.bootstrap.ReactorHttpServer; import org.springframework.http.server.reactive.bootstrap.RxNettyHttpServer; import org.springframework.http.server.reactive.bootstrap.TomcatHttpServer; +import org.springframework.http.server.reactive.bootstrap.UndertowHttpServer; import org.springframework.util.SocketUtils; import org.springframework.web.reactive.DispatcherHandler; import org.springframework.web.reactive.socket.server.support.HandshakeWebSocketService; @@ -40,6 +41,7 @@ import org.springframework.web.reactive.socket.server.support.WebSocketHandlerAd import org.springframework.web.reactive.socket.server.upgrade.ReactorNettyRequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.upgrade.RxNettyRequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.upgrade.TomcatRequestUpgradeStrategy; +import org.springframework.web.reactive.socket.server.upgrade.UndertowRequestUpgradeStrategy; /** * Base class for WebSocket integration tests involving a server-side @@ -68,7 +70,8 @@ public abstract class AbstractWebSocketHandlerIntegrationTests { return new Object[][] { {new ReactorHttpServer(), ReactorNettyConfig.class}, {new RxNettyHttpServer(), RxNettyConfig.class}, - {new TomcatHttpServer(base.getAbsolutePath(), WsContextListener.class), TomcatConfig.class} + {new TomcatHttpServer(base.getAbsolutePath(), WsContextListener.class), TomcatConfig.class}, + {new UndertowHttpServer(), UndertowConfig.class} }; } @@ -150,4 +153,13 @@ public abstract class AbstractWebSocketHandlerIntegrationTests { } } + @Configuration + static class UndertowConfig extends AbstractHandlerAdapterConfig { + + @Override + protected RequestUpgradeStrategy getUpgradeStrategy() { + return new UndertowRequestUpgradeStrategy(); + } + } + }