Browse Source

Consistently extend WebSocketHandlerAdapterSupport

The WebSocketHandler adapters for all runtimes now extend
WebSocketHandlerAdapterSupport, which now also exposes
a shared DataBufferFactory property initialized from the response.

Issue: SPR-14527
pull/1256/merge
Rossen Stoyanchev 9 years ago
parent
commit
d6895aa098
  1. 25
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java
  2. 13
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/ReactorNettyWebSocketHandlerAdapter.java
  3. 12
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/RxNettyWebSocketHandlerAdapter.java
  4. 125
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketHandlerAdapter.java
  5. 50
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java
  6. 5
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java
  7. 26
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java
  8. 92
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java
  9. 51
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java
  10. 26
      spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java

25
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java

@ -32,9 +32,8 @@ import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription; import org.reactivestreams.Subscription;
import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.CloseStatus; import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage; import org.springframework.web.reactive.socket.WebSocketMessage;
@ -48,21 +47,17 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type;
* @since 5.0 * @since 5.0
*/ */
@WebSocket @WebSocket
public class JettyWebSocketHandlerAdapter { public class JettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport {
private static final ByteBuffer EMPTY_PAYLOAD = ByteBuffer.wrap(new byte[0]); private static final ByteBuffer EMPTY_PAYLOAD = ByteBuffer.wrap(new byte[0]);
private final WebSocketHandler delegate;
private JettyWebSocketSession session; private JettyWebSocketSession session;
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false);
public JettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler delegate) {
public JettyWebSocketHandlerAdapter(WebSocketHandler delegate) { super(request, response, delegate);
Assert.notNull("WebSocketHandler is required");
this.delegate = delegate;
} }
@ -71,7 +66,7 @@ public class JettyWebSocketHandlerAdapter {
this.session = new JettyWebSocketSession(session); this.session = new JettyWebSocketSession(session);
HandlerResultSubscriber subscriber = new HandlerResultSubscriber(); HandlerResultSubscriber subscriber = new HandlerResultSubscriber();
this.delegate.handle(this.session).subscribe(subscriber); getDelegate().handle(this.session).subscribe(subscriber);
} }
@OnWebSocketMessage @OnWebSocketMessage
@ -105,15 +100,15 @@ public class JettyWebSocketHandlerAdapter {
private <T> WebSocketMessage toMessage(Type type, T message) { private <T> WebSocketMessage toMessage(Type type, T message) {
if (Type.TEXT.equals(type)) { if (Type.TEXT.equals(type)) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = this.bufferFactory.wrap(bytes); DataBuffer buffer = getBufferFactory().wrap(bytes);
return WebSocketMessage.create(Type.TEXT, buffer); return WebSocketMessage.create(Type.TEXT, buffer);
} }
else if (Type.BINARY.equals(type)) { else if (Type.BINARY.equals(type)) {
DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message); DataBuffer buffer = getBufferFactory().wrap((ByteBuffer) message);
return WebSocketMessage.create(Type.BINARY, buffer); return WebSocketMessage.create(Type.BINARY, buffer);
} }
else if (Type.PONG.equals(type)) { else if (Type.PONG.equals(type)) {
DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message); DataBuffer buffer = getBufferFactory().wrap((ByteBuffer) message);
return WebSocketMessage.create(Type.PONG, buffer); return WebSocketMessage.create(Type.PONG, buffer);
} }
else { else {

13
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/ReactorNettyWebSocketHandlerAdapter.java

@ -21,10 +21,8 @@ import org.reactivestreams.Publisher;
import reactor.ipc.netty.http.HttpInbound; import reactor.ipc.netty.http.HttpInbound;
import reactor.ipc.netty.http.HttpOutbound; import reactor.ipc.netty.http.HttpOutbound;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
/** /**
@ -38,22 +36,13 @@ public class ReactorNettyWebSocketHandlerAdapter extends WebSocketHandlerAdapter
implements BiFunction<HttpInbound, HttpOutbound, Publisher<Void>> { implements BiFunction<HttpInbound, HttpOutbound, Publisher<Void>> {
private final NettyDataBufferFactory bufferFactory;
public ReactorNettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response, public ReactorNettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler handler) { WebSocketHandler handler) {
super(request, handler); super(request, response, handler);
Assert.notNull("'response' is required");
this.bufferFactory = (NettyDataBufferFactory) response.bufferFactory();
} }
public NettyDataBufferFactory getBufferFactory() {
return this.bufferFactory;
}
@Override @Override
public Publisher<Void> apply(HttpInbound inbound, HttpOutbound outbound) { public Publisher<Void> apply(HttpInbound inbound, HttpOutbound outbound) {
ReactorNettyWebSocketSession session = ReactorNettyWebSocketSession session =

12
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/RxNettyWebSocketHandlerAdapter.java

@ -20,10 +20,8 @@ import reactor.core.publisher.Mono;
import rx.Observable; import rx.Observable;
import rx.RxReactiveStreams; import rx.RxReactiveStreams;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
/** /**
@ -36,22 +34,14 @@ import org.springframework.web.reactive.socket.WebSocketHandler;
public class RxNettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport public class RxNettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport
implements io.reactivex.netty.protocol.http.ws.server.WebSocketHandler { implements io.reactivex.netty.protocol.http.ws.server.WebSocketHandler {
private final NettyDataBufferFactory bufferFactory;
public RxNettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response, public RxNettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler handler) { WebSocketHandler handler) {
super(request, handler); super(request, response, handler);
Assert.notNull("'response' is required");
this.bufferFactory = (NettyDataBufferFactory) response.bufferFactory();
} }
public NettyDataBufferFactory getBufferFactory() {
return this.bufferFactory;
}
@Override @Override
public Observable<Void> handle(WebSocketConnection conn) { public Observable<Void> handle(WebSocketConnection conn) {
RxNettyWebSocketSession session = new RxNettyWebSocketSession(conn, getUri(), getBufferFactory()); RxNettyWebSocketSession session = new RxNettyWebSocketSession(conn, getUri(), getBufferFactory());

125
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketHandlerAdapter.java

@ -28,9 +28,8 @@ import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription; import org.reactivestreams.Subscription;
import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.CloseStatus; import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage; import org.springframework.web.reactive.socket.WebSocketMessage;
@ -43,76 +42,84 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type;
* @author Violeta Georgieva * @author Violeta Georgieva
* @since 5.0 * @since 5.0
*/ */
public class TomcatWebSocketHandlerAdapter extends Endpoint { public class TomcatWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport {
private final WebSocketHandler delegate;
private TomcatWebSocketSession session; private TomcatWebSocketSession session;
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false);
public TomcatWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler delegate) {
public TomcatWebSocketHandlerAdapter(WebSocketHandler delegate) { super(request, response, delegate);
Assert.notNull("WebSocketHandler is required");
this.delegate = delegate;
} }
@Override public Endpoint getEndpoint() {
public void onOpen(Session session, EndpointConfig config) { return new StandardEndpoint();
this.session = new TomcatWebSocketSession(session);
session.addMessageHandler(String.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.session.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
session.addMessageHandler(ByteBuffer.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.session.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
session.addMessageHandler(PongMessage.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.session.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber();
this.delegate.handle(this.session).subscribe(resultSubscriber);
} }
private <T> WebSocketMessage toMessage(T message) { private TomcatWebSocketSession getSession() {
if (message instanceof String) { return this.session;
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); }
return WebSocketMessage.create(Type.TEXT, this.bufferFactory.wrap(bytes));
}
else if (message instanceof ByteBuffer) { private class StandardEndpoint extends Endpoint {
DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message);
return WebSocketMessage.create(Type.BINARY, buffer); @Override
} public void onOpen(Session session, EndpointConfig config) {
else if (message instanceof PongMessage) { TomcatWebSocketHandlerAdapter.this.session = new TomcatWebSocketSession(session);
DataBuffer buffer = this.bufferFactory.wrap(((PongMessage) message).getApplicationData());
return WebSocketMessage.create(Type.PONG, buffer); session.addMessageHandler(String.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
getSession().handleMessage(webSocketMessage.getType(), webSocketMessage);
});
session.addMessageHandler(ByteBuffer.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
getSession().handleMessage(webSocketMessage.getType(), webSocketMessage);
});
session.addMessageHandler(PongMessage.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
getSession().handleMessage(webSocketMessage.getType(), webSocketMessage);
});
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber();
getDelegate().handle(TomcatWebSocketHandlerAdapter.this.session).subscribe(resultSubscriber);
} }
else {
throw new IllegalArgumentException("Unexpected message type: " + message); private <T> WebSocketMessage toMessage(T message) {
if (message instanceof String) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
return WebSocketMessage.create(Type.TEXT, getBufferFactory().wrap(bytes));
}
else if (message instanceof ByteBuffer) {
DataBuffer buffer = getBufferFactory().wrap((ByteBuffer) message);
return WebSocketMessage.create(Type.BINARY, buffer);
}
else if (message instanceof PongMessage) {
DataBuffer buffer = getBufferFactory().wrap(((PongMessage) message).getApplicationData());
return WebSocketMessage.create(Type.PONG, buffer);
}
else {
throw new IllegalArgumentException("Unexpected message type: " + message);
}
} }
}
@Override @Override
public void onClose(Session session, CloseReason reason) { public void onClose(Session session, CloseReason reason) {
if (this.session != null) { if (getSession() != null) {
int code = reason.getCloseCode().getCode(); int code = reason.getCloseCode().getCode();
this.session.handleClose(new CloseStatus(code, reason.getReasonPhrase())); getSession().handleClose(new CloseStatus(code, reason.getReasonPhrase()));
}
} }
}
@Override @Override
public void onError(Session session, Throwable exception) { public void onError(Session session, Throwable exception) {
if (this.session != null) { if (getSession() != null) {
this.session.handleError(exception); getSession().handleError(exception);
}
} }
} }
private final class HandlerResultSubscriber implements Subscriber<Void> { private final class HandlerResultSubscriber implements Subscriber<Void> {
@Override @Override
@ -127,15 +134,15 @@ public class TomcatWebSocketHandlerAdapter extends Endpoint {
@Override @Override
public void onError(Throwable ex) { public void onError(Throwable ex) {
if (session != null) { if (getSession() != null) {
session.close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage())); getSession().close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage()));
} }
} }
@Override @Override
public void onComplete() { public void onComplete() {
if (session != null) { if (getSession() != null) {
session.close(); getSession().close();
} }
} }
} }

50
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java

@ -16,30 +16,27 @@
package org.springframework.web.reactive.socket.adapter; package org.springframework.web.reactive.socket.adapter;
import java.net.URISyntaxException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import io.undertow.websockets.WebSocketConnectionCallback;
import io.undertow.websockets.core.AbstractReceiveListener;
import io.undertow.websockets.core.BufferedBinaryMessage;
import io.undertow.websockets.core.BufferedTextMessage;
import io.undertow.websockets.core.CloseMessage;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.spi.WebSocketHttpExchange;
import org.reactivestreams.Subscriber; import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription; import org.reactivestreams.Subscription;
import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.CloseStatus; import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage; import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketMessage.Type; import org.springframework.web.reactive.socket.WebSocketMessage.Type;
import io.undertow.websockets.WebSocketConnectionCallback;
import io.undertow.websockets.core.AbstractReceiveListener;
import io.undertow.websockets.core.BufferedBinaryMessage;
import io.undertow.websockets.core.BufferedTextMessage;
import io.undertow.websockets.core.CloseMessage;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.spi.WebSocketHttpExchange;
/** /**
* Undertow {@code WebSocketHandler} implementation adapting and * Undertow {@code WebSocketHandler} implementation adapting and
* delegating to a Spring {@link WebSocketHandler}. * delegating to a Spring {@link WebSocketHandler}.
@ -47,36 +44,27 @@ import io.undertow.websockets.spi.WebSocketHttpExchange;
* @author Violeta Georgieva * @author Violeta Georgieva
* @since 5.0 * @since 5.0
*/ */
public class UndertowWebSocketHandlerAdapter implements WebSocketConnectionCallback { public class UndertowWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport
implements WebSocketConnectionCallback {
private final WebSocketHandler delegate;
private UndertowWebSocketSession session; private UndertowWebSocketSession session;
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false);
public UndertowWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler delegate) {
public UndertowWebSocketHandlerAdapter(WebSocketHandler delegate) { super(request, response, delegate);
Assert.notNull("WebSocketHandler is required");
this.delegate = delegate;
} }
@Override @Override
public void onConnect(WebSocketHttpExchange exchange, WebSocketChannel channel) { public void onConnect(WebSocketHttpExchange exchange, WebSocketChannel channel) {
try { this.session = new UndertowWebSocketSession(channel, getUri());
this.session = new UndertowWebSocketSession(channel);
}
catch (URISyntaxException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
channel.getReceiveSetter().set(new UndertowReceiveListener()); channel.getReceiveSetter().set(new UndertowReceiveListener());
channel.resumeReceives(); channel.resumeReceives();
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber();
this.delegate.handle(this.session).subscribe(resultSubscriber); getDelegate().handle(this.session).subscribe(resultSubscriber);
} }
@ -114,14 +102,14 @@ public class UndertowWebSocketHandlerAdapter implements WebSocketConnectionCallb
private <T> WebSocketMessage toMessage(Type type, T message) { private <T> WebSocketMessage toMessage(Type type, T message) {
if (Type.TEXT.equals(type)) { if (Type.TEXT.equals(type)) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
return WebSocketMessage.create(Type.TEXT, bufferFactory.wrap(bytes)); return WebSocketMessage.create(Type.TEXT, getBufferFactory().wrap(bytes));
} }
else if (Type.BINARY.equals(type)) { else if (Type.BINARY.equals(type)) {
DataBuffer buffer = bufferFactory.allocateBuffer().write((ByteBuffer[]) message); DataBuffer buffer = getBufferFactory().allocateBuffer().write((ByteBuffer[]) message);
return WebSocketMessage.create(Type.BINARY, buffer); return WebSocketMessage.create(Type.BINARY, buffer);
} }
else if (Type.PONG.equals(type)) { else if (Type.PONG.equals(type)) {
DataBuffer buffer = bufferFactory.allocateBuffer().write((ByteBuffer[]) message); DataBuffer buffer = getBufferFactory().allocateBuffer().write((ByteBuffer[]) message);
return WebSocketMessage.create(Type.PONG, buffer); return WebSocketMessage.create(Type.PONG, buffer);
} }
else { else {

5
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java

@ -18,7 +18,6 @@ package org.springframework.web.reactive.socket.adapter;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -43,8 +42,8 @@ import org.springframework.web.reactive.socket.WebSocketSession;
public class UndertowWebSocketSession extends AbstractListenerWebSocketSession<WebSocketChannel> { public class UndertowWebSocketSession extends AbstractListenerWebSocketSession<WebSocketChannel> {
public UndertowWebSocketSession(WebSocketChannel channel) throws URISyntaxException { public UndertowWebSocketSession(WebSocketChannel channel, URI url) {
super(channel, ObjectUtils.getIdentityHexString(channel), new URI(channel.getUrl())); super(channel, ObjectUtils.getIdentityHexString(channel), url);
} }

26
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java

@ -17,12 +17,15 @@ package org.springframework.web.reactive.socket.adapter;
import java.net.URI; import java.net.URI;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
/** /**
* Base class for {@link WebSocketHandler} implementations. * Base class for {@link WebSocketHandler} adapters to underlying WebSocket
* handler APIs.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 5.0 * @since 5.0
@ -33,21 +36,32 @@ public abstract class WebSocketHandlerAdapterSupport {
private final WebSocketHandler delegate; private final WebSocketHandler delegate;
private final DataBufferFactory bufferFactory;
protected WebSocketHandlerAdapterSupport(ServerHttpRequest request, WebSocketHandler handler) {
Assert.notNull("'request' is required"); protected WebSocketHandlerAdapterSupport(ServerHttpRequest request, ServerHttpResponse response,
Assert.notNull("'handler' handler is required"); WebSocketHandler handler) {
Assert.notNull("ServerHttpRequest is required");
Assert.notNull("ServerHttpResponse is required");
Assert.notNull("WebSocketHandler handler is required");
this.uri = request.getURI(); this.uri = request.getURI();
this.bufferFactory = response.bufferFactory();
this.delegate = handler; this.delegate = handler;
} }
public URI getUri() { protected URI getUri() {
return this.uri; return this.uri;
} }
public WebSocketHandler getDelegate() { protected WebSocketHandler getDelegate() {
return this.delegate; return this.delegate;
} }
@SuppressWarnings("unchecked")
protected <T extends DataBufferFactory> T getBufferFactory() {
return (T) this.bufferFactory;
}
} }

92
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java

@ -17,16 +17,14 @@
package org.springframework.web.reactive.socket.server.upgrade; package org.springframework.web.reactive.socket.server.upgrade;
import java.io.IOException; import java.io.IOException;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.util.DecoratedObjectFactory; import org.eclipse.jetty.util.DecoratedObjectFactory;
import org.eclipse.jetty.websocket.server.WebSocketServerFactory; import org.eclipse.jetty.websocket.server.WebSocketServerFactory;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import reactor.core.publisher.Mono;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.springframework.context.Lifecycle; import org.springframework.context.Lifecycle;
import org.springframework.core.NamedThreadLocal; import org.springframework.core.NamedThreadLocal;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
@ -39,8 +37,6 @@ import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdap
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
/** /**
* A {@link RequestUpgradeStrategy} for use with Jetty. * A {@link RequestUpgradeStrategy} for use with Jetty.
* *
@ -52,43 +48,13 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
private static final ThreadLocal<JettyWebSocketHandlerAdapter> wsContainerHolder = private static final ThreadLocal<JettyWebSocketHandlerAdapter> wsContainerHolder =
new NamedThreadLocal<>("Jetty WebSocketHandler Adapter"); new NamedThreadLocal<>("Jetty WebSocketHandler Adapter");
private WebSocketServerFactory factory; private WebSocketServerFactory factory;
private ServletContext servletContext; private ServletContext servletContext;
private volatile boolean running = false; private volatile boolean running = false;
@Override
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler webSocketHandler) {
JettyWebSocketHandlerAdapter adapter =
new JettyWebSocketHandlerAdapter(webSocketHandler);
HttpServletRequest servletRequest = getHttpServletRequest(exchange.getRequest());
HttpServletResponse servletResponse = getHttpServletResponse(exchange.getResponse());
if (this.servletContext == null) {
this.servletContext = servletRequest.getServletContext();
servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory());
}
try {
start();
Assert.isTrue(this.factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake");
wsContainerHolder.set(adapter);
this.factory.acceptWebSocket(servletRequest, servletResponse);
}
catch (IOException ex) {
return Mono.error(ex);
}
finally {
wsContainerHolder.remove();
}
return Mono.empty();
}
@Override @Override
public void start() { public void start() {
@ -96,16 +62,10 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
this.running = true; this.running = true;
try { try {
this.factory = new WebSocketServerFactory(this.servletContext); this.factory = new WebSocketServerFactory(this.servletContext);
this.factory.setCreator(new WebSocketCreator() { this.factory.setCreator((request, response) -> {
JettyWebSocketHandlerAdapter adapter = wsContainerHolder.get();
@Override Assert.state(adapter != null, "Expected JettyWebSocketHandlerAdapter");
public Object createWebSocket(ServletUpgradeRequest req, return adapter;
ServletUpgradeResponse resp) {
JettyWebSocketHandlerAdapter adapter = wsContainerHolder.get();
Assert.state(adapter != null, "Expected JettyWebSocketHandlerAdapter");
return adapter;
}
}); });
this.factory.start(); this.factory.start();
} }
@ -133,12 +93,46 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
return this.running; return this.running;
} }
private final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) { @Override
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler) {
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(request, response, handler);
HttpServletRequest servletRequest = getHttpServletRequest(request);
HttpServletResponse servletResponse = getHttpServletResponse(response);
if (this.servletContext == null) {
this.servletContext = servletRequest.getServletContext();
this.servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory());
}
try {
start();
Assert.isTrue(this.factory.isUpgradeRequest(
servletRequest, servletResponse), "Not a WebSocket handshake");
wsContainerHolder.set(adapter);
this.factory.acceptWebSocket(servletRequest, servletResponse);
}
catch (IOException ex) {
return Mono.error(ex);
}
finally {
wsContainerHolder.remove();
}
return Mono.empty();
}
private HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
Assert.isTrue(request instanceof ServletServerHttpRequest); Assert.isTrue(request instanceof ServletServerHttpRequest);
return ((ServletServerHttpRequest) request).getServletRequest(); return ((ServletServerHttpRequest) request).getServletRequest();
} }
private final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) { private HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
Assert.isTrue(response instanceof ServletServerHttpResponse); Assert.isTrue(response instanceof ServletServerHttpResponse);
return ((ServletServerHttpResponse) response).getServletResponse(); return ((ServletServerHttpResponse) response).getServletResponse();
} }

51
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java

@ -24,6 +24,8 @@ import javax.servlet.ServletContext;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.apache.tomcat.websocket.server.WsServerContainer; import org.apache.tomcat.websocket.server.WsServerContainer;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
@ -50,45 +52,46 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy {
@Override @Override
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler webSocketHandler){ public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler){
TomcatWebSocketHandlerAdapter endpoint = ServerHttpRequest request = exchange.getRequest();
new TomcatWebSocketHandlerAdapter(webSocketHandler); ServerHttpResponse response = exchange.getResponse();
Endpoint endpoint = new TomcatWebSocketHandlerAdapter(request, response, handler).getEndpoint();
HttpServletRequest servletRequest = getHttpServletRequest(exchange.getRequest()); HttpServletRequest servletRequest = getHttpServletRequest(request);
HttpServletResponse servletResponse = getHttpServletResponse(exchange.getResponse()); HttpServletResponse servletResponse = getHttpServletResponse(response);
Map<String, String> pathParams = Collections.<String, String> emptyMap(); String requestURI = servletRequest.getRequestURI();
ServerEndpointConfig config = new ServerEndpointRegistration(requestURI, endpoint);
ServerEndpointRegistration sec =
new ServerEndpointRegistration(servletRequest.getRequestURI(), endpoint);
try { try {
getContainer(servletRequest).doUpgrade(servletRequest, servletResponse, WsServerContainer container = getContainer(servletRequest);
sec, pathParams); container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap());
} }
catch (ServletException | IOException e) { catch (ServletException | IOException ex) {
return Mono.error(e); return Mono.error(ex);
} }
return Mono.empty(); return Mono.empty();
} }
private WsServerContainer getContainer(HttpServletRequest request) { private HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
ServletContext servletContext = request.getServletContext();
Object container = servletContext.getAttribute(SERVER_CONTAINER_ATTR);
Assert.notNull(container, "No '" + SERVER_CONTAINER_ATTR + "' ServletContext attribute. " +
"Are you running in a Servlet container that supports JSR-356?");
Assert.isTrue(container instanceof WsServerContainer);
return (WsServerContainer) container;
}
private final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
Assert.isTrue(request instanceof ServletServerHttpRequest); Assert.isTrue(request instanceof ServletServerHttpRequest);
return ((ServletServerHttpRequest) request).getServletRequest(); return ((ServletServerHttpRequest) request).getServletRequest();
} }
private final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) { private HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
Assert.isTrue(response instanceof ServletServerHttpResponse); Assert.isTrue(response instanceof ServletServerHttpResponse);
return ((ServletServerHttpResponse) response).getServletResponse(); return ((ServletServerHttpResponse) response).getServletResponse();
} }
private WsServerContainer getContainer(HttpServletRequest request) {
ServletContext servletContext = request.getServletContext();
Object container = servletContext.getAttribute(SERVER_CONTAINER_ATTR);
Assert.notNull(container,
"No 'javax.websocket.server.ServerContainer' ServletContext attribute. " +
"Are you running in a Servlet container that supports JSR-356?");
Assert.isTrue(container instanceof WsServerContainer);
return (WsServerContainer) container;
}
} }

26
spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java

@ -17,6 +17,7 @@
package org.springframework.web.reactive.socket.server.upgrade; package org.springframework.web.reactive.socket.server.upgrade;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.UndertowServerHttpRequest; import org.springframework.http.server.reactive.UndertowServerHttpRequest;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
@ -25,6 +26,7 @@ import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import io.undertow.server.HttpServerExchange; import io.undertow.server.HttpServerExchange;
import io.undertow.websockets.WebSocketConnectionCallback;
import io.undertow.websockets.WebSocketProtocolHandshakeHandler; import io.undertow.websockets.WebSocketProtocolHandshakeHandler;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -37,27 +39,23 @@ import reactor.core.publisher.Mono;
public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
@Override @Override
public Mono<Void> upgrade(ServerWebExchange exchange, public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler) {
WebSocketHandler webSocketHandler) {
UndertowWebSocketHandlerAdapter callback = ServerHttpRequest request = exchange.getRequest();
new UndertowWebSocketHandlerAdapter(webSocketHandler); ServerHttpResponse response = exchange.getResponse();
WebSocketConnectionCallback callback = new UndertowWebSocketHandlerAdapter(request, response, handler);
Assert.isTrue(request instanceof UndertowServerHttpRequest);
HttpServerExchange httpExchange = ((UndertowServerHttpRequest) request).getUndertowExchange();
WebSocketProtocolHandshakeHandler handler =
new WebSocketProtocolHandshakeHandler(callback);
try { try {
handler.handleRequest(getUndertowExchange(exchange.getRequest())); new WebSocketProtocolHandshakeHandler(callback).handleRequest(httpExchange);
} }
catch (Exception e) { catch (Exception ex) {
return Mono.error(e); return Mono.error(ex);
} }
return Mono.empty(); return Mono.empty();
} }
private final HttpServerExchange getUndertowExchange(ServerHttpRequest request) {
Assert.isTrue(request instanceof UndertowServerHttpRequest);
return ((UndertowServerHttpRequest) request).getUndertowExchange();
}
} }

Loading…
Cancel
Save