Browse Source

Replace WebSocketHandlerAdapterSupport with delegation

This commit removes the base class WebSocketHandlerAdapterSupport which
was mainly a container for properties. Instead we use a
java.util.Function to create the WebSocketSession which differs in any
way by client and server, which in turn allows  HandshakeInfo to become
a simple immutable container once again.

Also for Undertow the WebSocketConnectionCallback implementation has
been moved into the server.upgrade package since it is for server-side
use only.

Issue: SPR-14527
pull/1026/merge
Rossen Stoyanchev 9 years ago
parent
commit
935577f00b
  1. 40
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/HandshakeInfo.java
  2. 42
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java
  3. 168
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardEndpoint.java
  4. 144
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardWebSocketHandlerAdapter.java
  5. 44
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java
  6. 66
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java
  7. 73
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java
  8. 2
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java
  9. 2
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java
  10. 87
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java
  11. 137
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/UndertowWebSocketClient.java
  12. 11
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClientSupport.java
  13. 10
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java
  14. 10
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java
  15. 57
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java

40
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/HandshakeInfo.java

@ -39,26 +39,27 @@ public class HandshakeInfo { @@ -39,26 +39,27 @@ public class HandshakeInfo {
private final Mono<Principal> principalMono;
private HttpHeaders headers;
private final HttpHeaders headers;
private Optional<String> protocol;
private final Optional<String> protocol;
public HandshakeInfo(URI uri, Mono<Principal> principal) {
this(uri, new HttpHeaders(), principal, Optional.empty());
}
public HandshakeInfo(URI uri, HttpHeaders headers, Mono<Principal> principal,
Optional<String> subProtocol) {
/**
* Constructor with information about the handshake.
* @param uri the endpoint URL
* @param headers request headers for server or response headers or client
* @param principal the principal for the session
* @param protocol the negotiated sub-protocol
*/
public HandshakeInfo(URI uri, HttpHeaders headers, Mono<Principal> principal, Optional<String> protocol) {
Assert.notNull(uri, "URI is required.");
Assert.notNull(headers, "HttpHeaders are required.");
Assert.notNull(principal, "Principal is required.");
Assert.notNull(subProtocol, "Sub-protocol is required.");
Assert.notNull(protocol, "Sub-protocol is required.");
this.uri = uri;
this.headers = headers;
this.principalMono = principal;
this.protocol = subProtocol;
this.protocol = protocol;
}
@ -77,15 +78,6 @@ public class HandshakeInfo { @@ -77,15 +78,6 @@ public class HandshakeInfo {
return this.headers;
}
/**
* Sets the handshake HTTP headers. Those are the request headers for a
* server session and the response headers for a client session.
* @param headers the handshake HTTP headers.
*/
public void setHeaders(HttpHeaders headers) {
this.headers = headers;
}
/**
* Return the principal associated with the handshake HTTP request.
*/
@ -102,14 +94,6 @@ public class HandshakeInfo { @@ -102,14 +94,6 @@ public class HandshakeInfo {
return this.protocol;
}
/**
* Sets the sub-protocol negotiated at handshake time.
* @param protocol the sub-protocol negotiated at handshake time.
*/
public void setSubProtocol(Optional<String> protocol) {
this.protocol = protocol;
}
@Override
public String toString() {

42
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java

@ -18,6 +18,7 @@ package org.springframework.web.reactive.socket.adapter; @@ -18,6 +18,7 @@ package org.springframework.web.reactive.socket.adapter;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.function.Function;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose;
@ -33,12 +34,12 @@ import org.reactivestreams.Subscription; @@ -33,12 +34,12 @@ import org.reactivestreams.Subscription;
import reactor.core.publisher.MonoProcessor;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketMessage.Type;
import org.springframework.web.reactive.socket.WebSocketSession;
/**
* Jetty {@link WebSocket @WebSocket} handler that delegates events to a
@ -49,34 +50,41 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type; @@ -49,34 +50,41 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type;
* @since 5.0
*/
@WebSocket
public class JettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport {
public class JettyWebSocketHandlerAdapter {
private static final ByteBuffer EMPTY_PAYLOAD = ByteBuffer.wrap(new byte[0]);
private JettyWebSocketSession delegateSession;
private final WebSocketHandler delegateHandler;
private final MonoProcessor<Void> completionMono;
private final Function<Session, JettyWebSocketSession> sessionFactory;
private JettyWebSocketSession delegateSession;
public JettyWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info,
DataBufferFactory bufferFactory) {
public JettyWebSocketHandlerAdapter(WebSocketHandler handler,
Function<Session, JettyWebSocketSession> sessionFactory) {
this(delegate, info, bufferFactory, null);
this(handler, null, sessionFactory);
}
public JettyWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info,
DataBufferFactory bufferFactory, MonoProcessor<Void> completionMono) {
public JettyWebSocketHandlerAdapter(WebSocketHandler handler, MonoProcessor<Void> completionMono,
Function<Session, JettyWebSocketSession> sessionFactory) {
super(delegate, info, bufferFactory);
Assert.notNull("WebSocketHandler is required");
Assert.notNull("'sessionFactory' is required");
this.delegateHandler = handler;
this.completionMono = completionMono;
this.sessionFactory = sessionFactory;
}
@OnWebSocketConnect
public void onWebSocketConnect(Session session) {
this.delegateSession = new JettyWebSocketSession(session, getHandshakeInfo(), bufferFactory());
this.delegateSession = sessionFactory.apply(session);
HandlerResultSubscriber subscriber = new HandlerResultSubscriber();
getDelegate().handle(this.delegateSession).subscribe(subscriber);
this.delegateHandler.handle(this.delegateSession).subscribe(subscriber);
}
@OnWebSocketMessage
@ -108,17 +116,19 @@ public class JettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport @@ -108,17 +116,19 @@ public class JettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport
}
private <T> WebSocketMessage toMessage(Type type, T message) {
WebSocketSession session = this.delegateSession;
Assert.state(session != null, "Cannot create message without a session");
if (Type.TEXT.equals(type)) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = bufferFactory().wrap(bytes);
DataBuffer buffer = session.bufferFactory().wrap(bytes);
return new WebSocketMessage(Type.TEXT, buffer);
}
else if (Type.BINARY.equals(type)) {
DataBuffer buffer = bufferFactory().wrap((ByteBuffer) message);
DataBuffer buffer = session.bufferFactory().wrap((ByteBuffer) message);
return new WebSocketMessage(Type.BINARY, buffer);
}
else if (Type.PONG.equals(type)) {
DataBuffer buffer = bufferFactory().wrap((ByteBuffer) message);
DataBuffer buffer = session.bufferFactory().wrap((ByteBuffer) message);
return new WebSocketMessage(Type.PONG, buffer);
}
else {

168
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardEndpoint.java

@ -1,168 +0,0 @@ @@ -1,168 +0,0 @@
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.socket.adapter;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import javax.websocket.CloseReason;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.PongMessage;
import javax.websocket.Session;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.MonoProcessor;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketMessage.Type;
/**
* {@link Endpoint} delegating events
* to a reactive {@link WebSocketHandler} and its session.
*
* @author Violeta Georgieva
* @since 5.0
*/
public class StandardEndpoint extends Endpoint {
private final WebSocketHandler handler;
private final HandshakeInfo info;
private final DataBufferFactory bufferFactory;
private StandardWebSocketSession delegateSession;
private final MonoProcessor<Void> completionMono;
public StandardEndpoint(WebSocketHandler handler, HandshakeInfo info,
DataBufferFactory bufferFactory) {
this(handler, info, bufferFactory, null);
}
public StandardEndpoint(WebSocketHandler handler, HandshakeInfo info,
DataBufferFactory bufferFactory, MonoProcessor<Void> completionMono) {
this.handler = handler;
this.info = info;
this.bufferFactory = bufferFactory;
this.completionMono = completionMono;
}
@Override
public void onOpen(Session nativeSession, EndpointConfig config) {
this.delegateSession = new StandardWebSocketSession(nativeSession, this.info, this.bufferFactory);
nativeSession.addMessageHandler(String.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
nativeSession.addMessageHandler(ByteBuffer.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
nativeSession.addMessageHandler(PongMessage.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber();
this.handler.handle(this.delegateSession).subscribe(resultSubscriber);
}
private <T> WebSocketMessage toMessage(T message) {
if (message instanceof String) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
return new WebSocketMessage(Type.TEXT, this.bufferFactory.wrap(bytes));
}
else if (message instanceof ByteBuffer) {
DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message);
return new WebSocketMessage(Type.BINARY, buffer);
}
else if (message instanceof PongMessage) {
DataBuffer buffer = this.bufferFactory.wrap(((PongMessage) message).getApplicationData());
return new WebSocketMessage(Type.PONG, buffer);
}
else {
throw new IllegalArgumentException("Unexpected message type: " + message);
}
}
@Override
public void onClose(Session session, CloseReason reason) {
if (this.delegateSession != null) {
int code = reason.getCloseCode().getCode();
this.delegateSession.handleClose(new CloseStatus(code, reason.getReasonPhrase()));
}
}
@Override
public void onError(Session session, Throwable exception) {
if (this.delegateSession != null) {
this.delegateSession.handleError(exception);
}
}
protected HandshakeInfo getHandshakeInfo() {
return this.info;
}
private final class HandlerResultSubscriber implements Subscriber<Void> {
@Override
public void onSubscribe(Subscription subscription) {
subscription.request(Long.MAX_VALUE);
}
@Override
public void onNext(Void aVoid) {
// no op
}
@Override
public void onError(Throwable ex) {
if (completionMono != null) {
completionMono.onError(ex);
}
if (delegateSession != null) {
int code = CloseStatus.SERVER_ERROR.getCode();
delegateSession.close(new CloseStatus(code, ex.getMessage()));
}
}
@Override
public void onComplete() {
if (completionMono != null) {
completionMono.onComplete();
}
if (delegateSession != null) {
delegateSession.close();
}
}
}
}

144
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/StandardWebSocketHandlerAdapter.java

@ -16,31 +16,153 @@ @@ -16,31 +16,153 @@
package org.springframework.web.reactive.socket.adapter;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.function.Function;
import javax.websocket.CloseReason;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.PongMessage;
import javax.websocket.Session;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.MonoProcessor;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketMessage.Type;
import org.springframework.web.reactive.socket.WebSocketSession;
/**
* Adapter for Java WebSocket API (JSR-356).
*
* Adapter for Java WebSocket API (JSR-356) that delegates events to a reactive
* {@link WebSocketHandler} and its session.
*
* @author Violeta Georgieva
* @author Rossen Stoyanchev
* @since 5.0
*/
public class StandardWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport {
public class StandardWebSocketHandlerAdapter extends Endpoint {
private final WebSocketHandler delegateHandler;
private final MonoProcessor<Void> completionMono;
private Function<Session, StandardWebSocketSession> sessionFactory;
public StandardWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info,
DataBufferFactory bufferFactory) {
private StandardWebSocketSession delegateSession;
super(delegate, info, bufferFactory);
public StandardWebSocketHandlerAdapter(WebSocketHandler handler,
Function<Session, StandardWebSocketSession> sessionFactory) {
this(handler, null, sessionFactory);
}
public StandardWebSocketHandlerAdapter(WebSocketHandler handler, MonoProcessor<Void> completionMono,
Function<Session, StandardWebSocketSession> sessionFactory) {
Assert.notNull("WebSocketHandler is required");
Assert.notNull("'sessionFactory' is required");
this.delegateHandler = handler;
this.completionMono = completionMono;
this.sessionFactory = sessionFactory;
}
@Override
public void onOpen(Session session, EndpointConfig config) {
this.delegateSession = this.sessionFactory.apply(session);
session.addMessageHandler(String.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
session.addMessageHandler(ByteBuffer.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
session.addMessageHandler(PongMessage.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber();
this.delegateHandler.handle(this.delegateSession).subscribe(resultSubscriber);
}
private <T> WebSocketMessage toMessage(T message) {
WebSocketSession session = this.delegateSession;
Assert.state(session != null, "Cannot create message without a session");
if (message instanceof String) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
return new WebSocketMessage(Type.TEXT, session.bufferFactory().wrap(bytes));
}
else if (message instanceof ByteBuffer) {
DataBuffer buffer = session.bufferFactory().wrap((ByteBuffer) message);
return new WebSocketMessage(Type.BINARY, buffer);
}
else if (message instanceof PongMessage) {
DataBuffer buffer = session.bufferFactory().wrap(((PongMessage) message).getApplicationData());
return new WebSocketMessage(Type.PONG, buffer);
}
else {
throw new IllegalArgumentException("Unexpected message type: " + message);
}
}
@Override
public void onClose(Session session, CloseReason reason) {
if (this.delegateSession != null) {
int code = reason.getCloseCode().getCode();
this.delegateSession.handleClose(new CloseStatus(code, reason.getReasonPhrase()));
}
}
@Override
public void onError(Session session, Throwable exception) {
if (this.delegateSession != null) {
this.delegateSession.handleError(exception);
}
}
private final class HandlerResultSubscriber implements Subscriber<Void> {
@Override
public void onSubscribe(Subscription subscription) {
subscription.request(Long.MAX_VALUE);
}
@Override
public void onNext(Void aVoid) {
// no op
}
@Override
public void onError(Throwable ex) {
if (completionMono != null) {
completionMono.onError(ex);
}
if (delegateSession != null) {
int code = CloseStatus.SERVER_ERROR.getCode();
delegateSession.close(new CloseStatus(code, ex.getMessage()));
}
}
public Endpoint getEndpoint() {
return new StandardEndpoint(getDelegate(), getHandshakeInfo(), bufferFactory());
@Override
public void onComplete() {
if (completionMono != null) {
completionMono.onComplete();
}
if (delegateSession != null) {
delegateSession.close();
}
}
}
}
}

44
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java

@ -25,20 +25,17 @@ import io.undertow.websockets.core.BufferedBinaryMessage; @@ -25,20 +25,17 @@ import io.undertow.websockets.core.BufferedBinaryMessage;
import io.undertow.websockets.core.BufferedTextMessage;
import io.undertow.websockets.core.CloseMessage;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.spi.WebSocketHttpExchange;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.MonoProcessor;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketMessage.Type;
import org.springframework.web.reactive.socket.WebSocketSession;
/**
* Undertow {@link WebSocketConnectionCallback} implementation that adapts and
@ -48,36 +45,35 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type; @@ -48,36 +45,35 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type;
* @author Rossen Stoyanchev
* @since 5.0
*/
public class UndertowWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport
implements WebSocketConnectionCallback {
public class UndertowWebSocketHandlerAdapter {
private UndertowWebSocketSession delegateSession;
private final WebSocketHandler delegateHandler;
private final MonoProcessor<Void> completionMono;
private UndertowWebSocketSession delegateSession;
public UndertowWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info,
DataBufferFactory bufferFactory) {
this(delegate, info, bufferFactory, null);
public UndertowWebSocketHandlerAdapter(WebSocketHandler handler) {
this(handler, null);
}
public UndertowWebSocketHandlerAdapter(WebSocketHandler delegate, HandshakeInfo info,
DataBufferFactory bufferFactory, MonoProcessor<Void> completionMono) {
public UndertowWebSocketHandlerAdapter(WebSocketHandler handler, MonoProcessor<Void> completionMono) {
super(delegate, info, bufferFactory);
Assert.notNull("WebSocketHandler is required");
Assert.notNull("'sessionFactory' is required");
this.delegateHandler = handler;
this.completionMono = completionMono;
}
@Override
public void onConnect(WebSocketHttpExchange exchange, WebSocketChannel channel) {
this.delegateSession = new UndertowWebSocketSession(channel, getHandshakeInfo(), bufferFactory());
channel.getReceiveSetter().set(new UndertowReceiveListener());
channel.resumeReceives();
public void handle(UndertowWebSocketSession webSocketSession) {
this.delegateSession = webSocketSession;
webSocketSession.getDelegate().getReceiveSetter().set(new UndertowReceiveListener());
webSocketSession.getDelegate().resumeReceives();
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber();
getDelegate().handle(this.delegateSession).subscribe(resultSubscriber);
this.delegateHandler.handle(this.delegateSession).subscribe(resultSubscriber);
}
@ -113,16 +109,18 @@ public class UndertowWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupp @@ -113,16 +109,18 @@ public class UndertowWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupp
}
private <T> WebSocketMessage toMessage(Type type, T message) {
WebSocketSession session = delegateSession;
Assert.state(session != null, "Cannot create message without a session");
if (Type.TEXT.equals(type)) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
return new WebSocketMessage(Type.TEXT, bufferFactory().wrap(bytes));
return new WebSocketMessage(Type.TEXT, session.bufferFactory().wrap(bytes));
}
else if (Type.BINARY.equals(type)) {
DataBuffer buffer = bufferFactory().allocateBuffer().write((ByteBuffer[]) message);
DataBuffer buffer = session.bufferFactory().allocateBuffer().write((ByteBuffer[]) message);
return new WebSocketMessage(Type.BINARY, buffer);
}
else if (Type.PONG.equals(type)) {
DataBuffer buffer = bufferFactory().allocateBuffer().write((ByteBuffer[]) message);
DataBuffer buffer = session.bufferFactory().allocateBuffer().write((ByteBuffer[]) message);
return new WebSocketMessage(Type.PONG, buffer);
}
else {

66
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java

@ -1,66 +0,0 @@ @@ -1,66 +0,0 @@
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.socket.adapter;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
/**
* Base class for adapters from event-listener WebSocket APIs (e.g. Java
* WebSocket API JSR-356, Jetty, Undertow) to the Reactive Streams based
* {@link WebSocketHandler}.
*
* @author Rossen Stoyanchev
* @since 5.0
*/
public abstract class WebSocketHandlerAdapterSupport {
private final WebSocketHandler delegate;
private final HandshakeInfo handshakeInfo;
private final DataBufferFactory bufferFactory;
protected WebSocketHandlerAdapterSupport(WebSocketHandler delegate, HandshakeInfo info,
DataBufferFactory bufferFactory) {
Assert.notNull(delegate, "WebSocketHandler delegate is required");
Assert.notNull(info, "HandshakeInfo is required.");
Assert.notNull(bufferFactory, "DataBufferFactory is required");
this.delegate = delegate;
this.handshakeInfo = info;
this.bufferFactory = bufferFactory;
}
protected WebSocketHandler getDelegate() {
return this.delegate;
}
protected HandshakeInfo getHandshakeInfo() {
return this.handshakeInfo;
}
@SuppressWarnings("unchecked")
protected <T extends DataBufferFactory> T bufferFactory() {
return (T) this.bufferFactory;
}
}

73
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java

@ -17,13 +17,10 @@ @@ -17,13 +17,10 @@
package org.springframework.web.reactive.socket.client;
import java.net.URI;
import java.util.Optional;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
@ -31,15 +28,16 @@ import org.springframework.context.Lifecycle; @@ -31,15 +28,16 @@ import org.springframework.context.Lifecycle;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.util.ObjectUtils;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession;
/**
* A Jetty based implementation of {@link WebSocketClient}.
*
* @author Violeta Georgieva
* @author Rossen Stoyanchev
* @since 5.0
*/
public class JettyWebSocketClient extends WebSocketClientSupport implements WebSocketClient, Lifecycle {
@ -76,31 +74,38 @@ public class JettyWebSocketClient extends WebSocketClientSupport implements WebS @@ -76,31 +74,38 @@ public class JettyWebSocketClient extends WebSocketClientSupport implements WebS
@Override
public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) {
return connectInternal(url, headers, handler);
return executeInternal(url, headers, handler);
}
private Mono<Void> connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
MonoProcessor<Void> processor = MonoProcessor.create();
private Mono<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
MonoProcessor<Void> completionMono = MonoProcessor.create();
return Mono.fromCallable(
() -> {
HandshakeInfo info = new HandshakeInfo(url, Mono.empty());
Object adapter = new JettyClientAdapter(handler, info, this.bufferFactory, processor);
ClientUpgradeRequest request = createRequest(url, headers, handler);
return this.wsClient.connect(adapter, url, request);
String[] protocols = beforeHandshake(url, headers, handler);
ClientUpgradeRequest upgradeRequest = createRequest(url, headers, protocols);
Object jettyHandler = createJettyHandler(url, handler, completionMono);
return this.wsClient.connect(jettyHandler, url, upgradeRequest);
})
.then(processor);
.then(completionMono);
}
private ClientUpgradeRequest createRequest(URI url, HttpHeaders headers, WebSocketHandler handler) {
ClientUpgradeRequest request = new ClientUpgradeRequest();
String[] protocols = beforeHandshake(url, headers, handler);
if (!ObjectUtils.isEmpty(protocols)) {
request.setSubProtocols(protocols);
}
private Object createJettyHandler(URI url, WebSocketHandler handler, MonoProcessor<Void> completion) {
return new JettyWebSocketHandlerAdapter(
handler, completion, session -> createJettySession(url, session));
}
headers.forEach((k, v) -> request.setHeader(k, v));
private JettyWebSocketSession createJettySession(URI url, Session session) {
UpgradeResponse response = session.getUpgradeResponse();
HttpHeaders responseHeaders = new HttpHeaders();
response.getHeaders().forEach(responseHeaders::put);
HandshakeInfo info = afterHandshake(url, responseHeaders);
return new JettyWebSocketSession(session, info, bufferFactory);
}
private ClientUpgradeRequest createRequest(URI url, HttpHeaders headers, String[] protocols) {
ClientUpgradeRequest request = new ClientUpgradeRequest();
request.setSubProtocols(protocols);
headers.forEach(request::setHeader);
return request;
}
@ -140,32 +145,4 @@ public class JettyWebSocketClient extends WebSocketClientSupport implements WebS @@ -140,32 +145,4 @@ public class JettyWebSocketClient extends WebSocketClientSupport implements WebS
}
}
@WebSocket
private static final class JettyClientAdapter extends JettyWebSocketHandlerAdapter {
public JettyClientAdapter(WebSocketHandler delegate,
HandshakeInfo info, DataBufferFactory bufferFactory, MonoProcessor<Void> processor) {
super(delegate, info, bufferFactory, processor);
}
@Override
public void onWebSocketConnect(Session session) {
UpgradeResponse response = session.getUpgradeResponse();
getHandshakeInfo().setHeaders(getResponseHeaders(response));
getHandshakeInfo().setSubProtocol(
Optional.ofNullable(response.getAcceptedSubProtocol()));
super.onWebSocketConnect(session);
}
private HttpHeaders getResponseHeaders(UpgradeResponse response) {
HttpHeaders responseHeaders = new HttpHeaders();
response.getHeaders().forEach((k, v) -> responseHeaders.put(k, v));
return responseHeaders;
}
}
}

2
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java

@ -70,7 +70,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen @@ -70,7 +70,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen
})
.then(response -> {
HttpHeaders responseHeaders = getResponseHeaders(response);
HandshakeInfo info = afterHandshake(url, response.status().code(), responseHeaders);
HandshakeInfo info = afterHandshake(url, responseHeaders);
ByteBufAllocator allocator = response.channel().alloc();
NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator);

2
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java

@ -114,7 +114,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We @@ -114,7 +114,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
.flatMap(tuple -> {
WebSocketResponse<ByteBuf> response = tuple.getT1();
HttpHeaders responseHeaders = getResponseHeaders(response);
HandshakeInfo info = afterHandshake(url, response.getStatus().code(), responseHeaders);
HandshakeInfo info = afterHandshake(url, responseHeaders);
ByteBufAllocator allocator = response.unsafeNettyChannel().alloc();
NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator);

87
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java

@ -20,25 +20,25 @@ import java.net.URI; @@ -20,25 +20,25 @@ import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ClientEndpointConfig.Configurator;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.HandshakeResponse;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.ClientEndpointConfig.Configurator;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import reactor.core.scheduler.Schedulers;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.StandardEndpoint;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketSession;
/**
* A Java WebSocket API (JSR-356) based implementation of
@ -78,70 +78,59 @@ public class StandardWebSocketClient extends WebSocketClientSupport implements W @@ -78,70 +78,59 @@ public class StandardWebSocketClient extends WebSocketClientSupport implements W
@Override
public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) {
return connectInternal(url, headers, handler);
return executeInternal(url, headers, handler);
}
private Mono<Void> connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
MonoProcessor<Void> processor = MonoProcessor.create();
return Mono.fromCallable(() -> {
StandardWebSocketClientConfigurator configurator =
new StandardWebSocketClientConfigurator(headers);
ClientEndpointConfig endpointConfig = createClientEndpointConfig(
configurator, beforeHandshake(url, headers, handler));
HandshakeInfo info = new HandshakeInfo(url, Mono.empty());
Endpoint endpoint = new StandardClientEndpoint(handler, info,
this.bufferFactory, configurator, processor);
Session session = this.wsContainer.connectToServer(endpoint, endpointConfig, url);
return session;
}).then(processor);
private Mono<Void> executeInternal(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) {
MonoProcessor<Void> completionMono = MonoProcessor.create();
return Mono.fromCallable(
() -> {
String[] subProtocols = beforeHandshake(url, requestHeaders, handler);
DefaultConfigurator configurator = new DefaultConfigurator(requestHeaders);
ClientEndpointConfig config = createEndpointConfig(configurator, subProtocols);
Endpoint endpoint = createEndpoint(url, handler, completionMono, configurator);
return this.wsContainer.connectToServer(endpoint, config, url);
})
.subscribeOn(Schedulers.elastic()) // connectToServer is blocking
.then(completionMono);
}
private ClientEndpointConfig createClientEndpointConfig(
StandardWebSocketClientConfigurator configurator, String[] subProtocols) {
private ClientEndpointConfig createEndpointConfig(Configurator configurator, String[] subProtocols) {
return ClientEndpointConfig.Builder.create()
.configurator(configurator)
.preferredSubprotocols(Arrays.asList(subProtocols))
.build();
}
private StandardWebSocketHandlerAdapter createEndpoint(URI url, WebSocketHandler handler,
MonoProcessor<Void> completion, DefaultConfigurator configurator) {
private static final class StandardClientEndpoint extends StandardEndpoint {
private final StandardWebSocketClientConfigurator configurator;
public StandardClientEndpoint(WebSocketHandler handler, HandshakeInfo info,
DataBufferFactory bufferFactory, StandardWebSocketClientConfigurator configurator,
MonoProcessor<Void> processor) {
super(handler, info, bufferFactory, processor);
this.configurator = configurator;
}
@Override
public void onOpen(Session nativeSession, EndpointConfig config) {
getHandshakeInfo().setHeaders(this.configurator.getResponseHeaders());
getHandshakeInfo().setSubProtocol(
Optional.ofNullable(nativeSession.getNegotiatedSubprotocol()));
return new StandardWebSocketHandlerAdapter(handler, completion,
session -> createSession(url, configurator.getResponseHeaders(), session));
}
super.onOpen(nativeSession, config);
}
private StandardWebSocketSession createSession(URI url, HttpHeaders responseHeaders, Session session) {
HandshakeInfo info = afterHandshake(url, responseHeaders);
return new StandardWebSocketSession(session, info, this.bufferFactory);
}
private static final class StandardWebSocketClientConfigurator extends Configurator {
private static final class DefaultConfigurator extends Configurator {
private final HttpHeaders requestHeaders;
private HttpHeaders responseHeaders = new HttpHeaders();
private final HttpHeaders responseHeaders = new HttpHeaders();
public StandardWebSocketClientConfigurator(HttpHeaders requestHeaders) {
public DefaultConfigurator(HttpHeaders requestHeaders) {
this.requestHeaders = requestHeaders;
}
public HttpHeaders getResponseHeaders() {
return this.responseHeaders;
}
@Override
public void beforeRequest(Map<String, List<String>> requestHeaders) {
requestHeaders.putAll(this.requestHeaders);
@ -149,11 +138,7 @@ public class StandardWebSocketClient extends WebSocketClientSupport implements W @@ -149,11 +138,7 @@ public class StandardWebSocketClient extends WebSocketClientSupport implements W
@Override
public void afterResponse(HandshakeResponse response) {
response.getHeaders().forEach((k, v) -> responseHeaders.put(k, v));
}
public HttpHeaders getResponseHeaders() {
return this.responseHeaders;
response.getHeaders().forEach(this.responseHeaders::put);
}
}

137
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/UndertowWebSocketClient.java

@ -23,26 +23,20 @@ import java.util.Arrays; @@ -23,26 +23,20 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.function.Function;
import javax.net.ssl.SSLContext;
import io.undertow.protocols.ssl.UndertowXnioSsl;
import io.undertow.server.DefaultByteBufferPool;
import io.undertow.websockets.WebSocketExtension;
import io.undertow.websockets.client.WebSocketClient.ConnectionBuilder;
import io.undertow.websockets.client.WebSocketClientNegotiation;
import io.undertow.websockets.core.WebSocketChannel;
import org.xnio.IoFuture;
import org.xnio.IoFuture.Notifier;
import org.xnio.IoFuture.Status;
import org.xnio.OptionMap;
import org.xnio.Options;
import org.xnio.Xnio;
import org.xnio.XnioWorker;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
@ -52,11 +46,13 @@ import org.springframework.http.HttpHeaders; @@ -52,11 +46,13 @@ import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession;
/**
* An Undertow based implementation of {@link WebSocketClient}.
*
* @author Violeta Georgieva
* @author Rossen Stoyanchev
* @since 5.0
*/
public class UndertowWebSocketClient extends WebSocketClientSupport implements WebSocketClient {
@ -82,7 +78,10 @@ public class UndertowWebSocketClient extends WebSocketClientSupport implements W @@ -82,7 +78,10 @@ public class UndertowWebSocketClient extends WebSocketClientSupport implements W
}
}
private final Function<URI, io.undertow.websockets.client.WebSocketClient.ConnectionBuilder> builder;
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
private final Function<URI, ConnectionBuilder> builder;
/**
@ -100,15 +99,13 @@ public class UndertowWebSocketClient extends WebSocketClientSupport implements W @@ -100,15 +99,13 @@ public class UndertowWebSocketClient extends WebSocketClientSupport implements W
* instance.
* @param builder a connection builder that can be used to create a web socket connection.
*/
public UndertowWebSocketClient(Function<URI,
io.undertow.websockets.client.WebSocketClient.ConnectionBuilder> builder) {
public UndertowWebSocketClient(Function<URI, ConnectionBuilder> builder) {
this.builder = builder;
}
private static io.undertow.websockets.client.WebSocketClient.ConnectionBuilder createDefaultConnectionBuilder(
URI url) {
private static ConnectionBuilder createDefaultConnectionBuilder(URI url) {
io.undertow.websockets.client.WebSocketClient.ConnectionBuilder builder =
ConnectionBuilder builder =
io.undertow.websockets.client.WebSocketClient.connectionBuilder(
worker, new DefaultByteBufferPool(false, DEFAULT_BUFFER_SIZE), url);
@ -135,99 +132,77 @@ public class UndertowWebSocketClient extends WebSocketClientSupport implements W @@ -135,99 +132,77 @@ public class UndertowWebSocketClient extends WebSocketClientSupport implements W
@Override
public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) {
return connectInternal(url, headers, handler);
return executeInternal(url, headers, handler);
}
private Mono<Void> connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
MonoProcessor<Void> processor = MonoProcessor.create();
private Mono<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
MonoProcessor<Void> completionMono = MonoProcessor.create();
return Mono.fromCallable(
() -> {
WSClientNegotiation clientNegotiation =
new WSClientNegotiation(beforeHandshake(url, headers, handler),
Collections.emptyList(), headers);
io.undertow.websockets.client.WebSocketClient.ConnectionBuilder builder =
this.builder.apply(url).setClientNegotiation(clientNegotiation);
IoFuture<WebSocketChannel> future = builder.connect();
future.addNotifier(new ResultNotifier(url, handler, clientNegotiation, processor), new Object());
return future;
String[] subProtocols = beforeHandshake(url, headers, handler);
DefaultNegotiation negotiation = new DefaultNegotiation(subProtocols, headers);
return this.builder.apply(url)
.setClientNegotiation(negotiation)
.connect()
.addNotifier((future, attachment) -> {
if (Status.DONE.equals(future.getStatus())) {
WebSocketChannel channel;
try {
channel = future.get();
}
catch (CancellationException | IOException ex) {
completionMono.onError(ex);
return;
}
handleWebSocket(url, handler, completionMono, negotiation, channel);
}
else if (Status.FAILED.equals(future.getStatus())) {
completionMono.onError(future.getException());
}
else {
String message = "Failed to connect" + future.getStatus();
completionMono.onError(new IllegalStateException(message));
}
}, null);
})
.then(processor);
.then(completionMono);
}
private void handleWebSocket(URI url, WebSocketHandler handler, MonoProcessor<Void> completionMono,
DefaultNegotiation negotiation, WebSocketChannel channel) {
private static final class ResultNotifier implements Notifier<WebSocketChannel, Object> {
private final URI url;
private final WebSocketHandler handler;
private final WSClientNegotiation clientNegotiation;
private final MonoProcessor<Void> processor;
public ResultNotifier(URI url, WebSocketHandler handler,
WSClientNegotiation clientNegotiation, MonoProcessor<Void> processor) {
this.url = url;
this.handler = handler;
this.clientNegotiation = clientNegotiation;
this.processor = processor;
}
@Override
public void notify(IoFuture<? extends WebSocketChannel> ioFuture,
Object attachment) {
if (Status.CANCELLED.equals(ioFuture.getStatus())) {
processor.onError(null);
}
else if (Status.FAILED.equals(ioFuture.getStatus())) {
processor.onError(ioFuture.getException());
}
else if (Status.DONE.equals(ioFuture.getStatus())) {
try {
WebSocketChannel channel = ioFuture.get();
DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
HandshakeInfo info = new HandshakeInfo(url, clientNegotiation.getResponseHeaders(),
Mono.empty(), Optional.ofNullable(channel.getSubProtocol()));
UndertowWebSocketHandlerAdapter adapter =
new UndertowWebSocketHandlerAdapter(handler,
info, bufferFactory, processor);
adapter.onConnect(null, channel);
}
catch (CancellationException | IOException ex) {
processor.onError(ex);
}
}
}
HandshakeInfo info = afterHandshake(url, negotiation.getResponseHeaders());
UndertowWebSocketSession session = new UndertowWebSocketSession(channel, info, this.bufferFactory);
new UndertowWebSocketHandlerAdapter(handler, completionMono).handle(session);
}
private static final class WSClientNegotiation extends WebSocketClientNegotiation {
private static final class DefaultNegotiation extends WebSocketClientNegotiation {
private final HttpHeaders requestHeaders;
private HttpHeaders responseHeaders = new HttpHeaders();
public WSClientNegotiation(String[] subProtocols,
List<WebSocketExtension> extensions, HttpHeaders requestHeaders) {
super(Arrays.asList(subProtocols), extensions);
public DefaultNegotiation(String[] subProtocols, HttpHeaders requestHeaders) {
super(Arrays.asList(subProtocols), Collections.emptyList());
this.requestHeaders = requestHeaders;
}
public HttpHeaders getResponseHeaders() {
return this.responseHeaders;
}
@Override
public void beforeRequest(Map<String, List<String>> headers) {
requestHeaders.forEach((k, v) -> headers.put(k, v));
this.requestHeaders.forEach(headers::put);
}
@Override
public void afterRequest(Map<String, List<String>> headers) {
headers.forEach((k, v) -> responseHeaders.put(k, v));
}
public HttpHeaders getResponseHeaders() {
return responseHeaders;
headers.forEach((k, v) -> this.responseHeaders.put(k, v));
}
}

11
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClientSupport.java

@ -41,20 +41,19 @@ public class WebSocketClientSupport { @@ -41,20 +41,19 @@ public class WebSocketClientSupport {
protected final Log logger = LogFactory.getLog(getClass());
protected String[] beforeHandshake(URI url, HttpHeaders headers, WebSocketHandler handler) {
protected String[] beforeHandshake(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) {
if (logger.isDebugEnabled()) {
logger.debug("Executing handshake to " + url);
}
return handler.getSubProtocols();
}
protected HandshakeInfo afterHandshake(URI url, int statusCode, HttpHeaders headers) {
Assert.isTrue(statusCode == 101);
protected HandshakeInfo afterHandshake(URI url, HttpHeaders responseHeaders) {
if (logger.isDebugEnabled()) {
logger.debug("Handshake response: " + url + ", " + headers);
logger.debug("Handshake response: " + url + ", " + responseHeaders);
}
String protocol = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
return new HandshakeInfo(url, headers, Mono.empty(), Optional.ofNullable(protocol));
String protocol = responseHeaders.getFirst(SEC_WEBSOCKET_PROTOCOL);
return new HandshakeInfo(url, responseHeaders, Mono.empty(), Optional.ofNullable(protocol));
}
}

10
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java

@ -37,6 +37,7 @@ import org.springframework.util.Assert; @@ -37,6 +37,7 @@ import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange;
@ -118,9 +119,12 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life @@ -118,9 +119,12 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
HttpServletRequest servletRequest = getHttpServletRequest(request);
HttpServletResponse servletResponse = getHttpServletResponse(response);
HandshakeInfo info = getHandshakeInfo(exchange, subProtocol);
DataBufferFactory factory = response.bufferFactory();
JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(handler, info, factory);
JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(handler,
null, session -> {
HandshakeInfo info = getHandshakeInfo(exchange, subProtocol);
DataBufferFactory factory = response.bufferFactory();
return new JettyWebSocketSession(session, info, factory);
});
startLazily(servletRequest);

10
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java

@ -38,6 +38,7 @@ import org.springframework.util.Assert; @@ -38,6 +38,7 @@ import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange;
@ -63,9 +64,12 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy { @@ -63,9 +64,12 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy {
HttpServletRequest servletRequest = getHttpServletRequest(request);
HttpServletResponse servletResponse = getHttpServletResponse(response);
HandshakeInfo info = getHandshakeInfo(exchange, subProtocol);
DataBufferFactory factory = response.bufferFactory();
Endpoint endpoint = new StandardWebSocketHandlerAdapter(handler, info, factory).getEndpoint();
Endpoint endpoint = new StandardWebSocketHandlerAdapter(handler,
session -> {
HandshakeInfo info = getHandshakeInfo(exchange, subProtocol);
DataBufferFactory factory = response.bufferFactory();
return new StandardWebSocketSession(session, info, factory);
});
String requestURI = servletRequest.getRequestURI();
DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint);

57
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java

@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
package org.springframework.web.reactive.socket.server.upgrade;
import java.net.URI;
import java.security.Principal;
import java.util.Collections;
import java.util.List;
@ -25,18 +26,21 @@ import java.util.Set; @@ -25,18 +26,21 @@ import java.util.Set;
import io.undertow.server.HttpServerExchange;
import io.undertow.websockets.WebSocketConnectionCallback;
import io.undertow.websockets.WebSocketProtocolHandshakeHandler;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.protocol.Handshake;
import io.undertow.websockets.core.protocol.version13.Hybi13Handshake;
import io.undertow.websockets.spi.WebSocketHttpExchange;
import reactor.core.publisher.Mono;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.UndertowServerHttpRequest;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange;
@ -55,22 +59,15 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { @@ -55,22 +59,15 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
Optional<String> subProtocol) {
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
HandshakeInfo info = getHandshakeInfo(exchange, subProtocol);
DataBufferFactory bufferFactory = response.bufferFactory();
Assert.isTrue(request instanceof UndertowServerHttpRequest);
HttpServerExchange httpExchange = ((UndertowServerHttpRequest) request).getUndertowExchange();
WebSocketConnectionCallback callback =
new UndertowWebSocketHandlerAdapter(handler, info, bufferFactory);
Set<String> protocols = subProtocol.map(Collections::singleton).orElse(Collections.emptySet());
Hybi13Handshake handshake = new Hybi13Handshake(protocols, false);
List<Handshake> handshakes = Collections.singletonList(handshake);
try {
DefaultCallback callback = new DefaultCallback(exchange, handler, subProtocol);
new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange);
}
catch (Exception ex) {
@ -80,10 +77,44 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { @@ -80,10 +77,44 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
return Mono.empty();
}
private HandshakeInfo getHandshakeInfo(ServerWebExchange exchange, Optional<String> protocol) {
ServerHttpRequest request = exchange.getRequest();
Mono<Principal> principal = exchange.getPrincipal();
return new HandshakeInfo(request.getURI(), request.getHeaders(), principal, protocol);
private class DefaultCallback implements WebSocketConnectionCallback {
private final ServerWebExchange exchange;
private final WebSocketHandler handler;
private final Optional<String> subProtocol;
public DefaultCallback(ServerWebExchange exchange, WebSocketHandler handler,
Optional<String> subProtocol) {
this.exchange = exchange;
this.handler = handler;
this.subProtocol = subProtocol;
}
@Override
public void onConnect(WebSocketHttpExchange httpExchange, WebSocketChannel channel) {
UndertowWebSocketHandlerAdapter adapter = new UndertowWebSocketHandlerAdapter(this.handler);
UndertowWebSocketSession session = createWebSocketSession(channel);
adapter.handle(session);
}
private UndertowWebSocketSession createWebSocketSession(WebSocketChannel channel) {
HandshakeInfo info = getHandshakeInfo();
DataBufferFactory bufferFactory = this.exchange.getResponse().bufferFactory();
return new UndertowWebSocketSession(channel, info, bufferFactory);
}
private HandshakeInfo getHandshakeInfo() {
ServerHttpRequest request = this.exchange.getRequest();
URI url = request.getURI();
HttpHeaders headers = request.getHeaders();
Mono<Principal> principal = this.exchange.getPrincipal();
return new HandshakeInfo(url, headers, principal, this.subProtocol);
}
}
}

Loading…
Cancel
Save