diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java index f898083b586..41427095e75 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java @@ -37,6 +37,7 @@ import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdap import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.server.ServerWebExchange; + /** * A {@link RequestUpgradeStrategy} for use with Jetty. * @@ -45,52 +46,58 @@ import org.springframework.web.server.ServerWebExchange; */ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle { - private static final ThreadLocal wsContainerHolder = - new NamedThreadLocal<>("Jetty WebSocketHandler Adapter"); + private static final ThreadLocal adapterHolder = + new NamedThreadLocal<>("JettyWebSocketHandlerAdapter"); private WebSocketServerFactory factory; private ServletContext servletContext; - private volatile boolean running = false; + private boolean running = false; + + private final Object lifecycleMonitor = new Object(); @Override public void start() { - if (!isRunning() && this.servletContext != null) { - this.running = true; - try { - this.factory = new WebSocketServerFactory(this.servletContext); - this.factory.setCreator((request, response) -> { - JettyWebSocketHandlerAdapter adapter = wsContainerHolder.get(); - Assert.state(adapter != null, "Expected JettyWebSocketHandlerAdapter"); - return adapter; - }); - this.factory.start(); - } - catch (Exception ex) { - throw new IllegalStateException("Unable to start Jetty WebSocketServerFactory", ex); + synchronized (this.lifecycleMonitor) { + if (!isRunning() && this.servletContext != null) { + this.running = true; + try { + this.factory = new WebSocketServerFactory(this.servletContext); + this.factory.setCreator((request, response) -> adapterHolder.get()); + this.factory.start(); + } + catch (Exception ex) { + throw new IllegalStateException("Unable to start WebSocketServerFactory", ex); + } } } } @Override public void stop() { - if (isRunning()) { - this.running = false; - try { - this.factory.stop(); - } - catch (Exception ex) { - throw new IllegalStateException("Unable to stop Jetty WebSocketServerFactory", ex); + synchronized (this.lifecycleMonitor) { + if (isRunning()) { + try { + this.factory.stop(); + } + catch (Exception ex) { + throw new IllegalStateException("Failed to stop WebSocketServerFactory", ex); + } + finally { + this.running = false; + } } } } @Override public boolean isRunning() { - return this.running; + synchronized (this.lifecycleMonitor) { + return this.running; + } } @Override @@ -103,25 +110,20 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life HttpServletRequest servletRequest = getHttpServletRequest(request); HttpServletResponse servletResponse = getHttpServletResponse(response); - if (this.servletContext == null) { - this.servletContext = servletRequest.getServletContext(); - this.servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory()); - } - - try { - start(); + startLazily(servletRequest); - Assert.isTrue(this.factory.isUpgradeRequest( - servletRequest, servletResponse), "Not a WebSocket handshake"); + boolean isUpgrade = this.factory.isUpgradeRequest(servletRequest, servletResponse); + Assert.isTrue(isUpgrade, "Not a WebSocket handshake"); - wsContainerHolder.set(adapter); + try { + adapterHolder.set(adapter); this.factory.acceptWebSocket(servletRequest, servletResponse); } catch (IOException ex) { return Mono.error(ex); } finally { - wsContainerHolder.remove(); + adapterHolder.remove(); } return Mono.empty(); @@ -137,4 +139,17 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life return ((ServletServerHttpResponse) response).getServletResponse(); } + private void startLazily(HttpServletRequest request) { + if (this.servletContext != null) { + return; + } + synchronized (this.lifecycleMonitor) { + if (this.servletContext == null) { + this.servletContext = request.getServletContext(); + this.servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory()); + start(); + } + } + } + } diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketHandlerIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketHandlerIntegrationTests.java index 17d7bb49702..ff46885195d 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketHandlerIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/AbstractWebSocketHandlerIntegrationTests.java @@ -66,7 +66,7 @@ public abstract class AbstractWebSocketHandlerIntegrationTests { public Class handlerAdapterConfigClass; - @Parameters + @Parameters(name = "server [{0}]") public static Object[][] arguments() { File base = new File(System.getProperty("java.io.tmpdir")); return new Object[][] { diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/BasicWebSocketHandlerIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/BasicWebSocketHandlerIntegrationTests.java index 8ea6af72279..09264b54523 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/BasicWebSocketHandlerIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/socket/server/BasicWebSocketHandlerIntegrationTests.java @@ -20,7 +20,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.reactivex.netty.protocol.http.client.HttpClient; @@ -66,7 +65,11 @@ public class BasicWebSocketHandlerIntegrationTests extends AbstractWebSocketHand .mergeWith(conn.getInput()) ) .take(10) - .map(frame -> frame.content().toString(StandardCharsets.UTF_8)) + .map(frame -> { + String text = frame.content().toString(StandardCharsets.UTF_8); + frame.release(); + return text; + }) .toList().toBlocking().first(); List expected = messages.toList().toBlocking().first(); assertEquals(expected, actual);