diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java index ddd4226a607..9accb2c6799 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; @@ -275,12 +274,10 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec return this.messagingTemplate; } - public void send(UserDestinationResult destinationResult, Message message) throws MessagingException { - Set sessionIds = destinationResult.getSessionIds(); - Iterator itr = (sessionIds != null ? sessionIds.iterator() : null); - - for (String target : destinationResult.getTargetDestinations()) { - String sessionId = (itr != null ? itr.next() : null); + public void send(UserDestinationResult result, Message message) throws MessagingException { + Iterator itr = result.getSessionIds().iterator(); + for (String target : result.getTargetDestinations()) { + String sessionId = (itr.hasNext() ? itr.next() : null); getTemplateToUse(sessionId).send(target, message); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationResult.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationResult.java index c47607bac9a..b6190c6ebcc 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationResult.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationResult.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,7 +44,11 @@ public class UserDestinationResult { private final Set sessionIds; - public UserDestinationResult(String sourceDestination, Set targetDestinations, + /** + * Main constructor. + */ + public UserDestinationResult( + String sourceDestination, Set targetDestinations, String subscribeDestination, @Nullable String user) { this(sourceDestination, targetDestinations, subscribeDestination, user, null); @@ -113,7 +117,7 @@ public class UserDestinationResult { /** * Return the session id for the targetDestination. */ - public @Nullable Set getSessionIds() { + public Set getSessionIds() { return this.sessionIds; } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java index 0f2548f879e..e1900f09881 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.messaging.simp.user; import java.nio.charset.StandardCharsets; +import java.util.Set; import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; @@ -98,6 +99,26 @@ class UserDestinationMessageHandlerTests { assertThat(accessor.getFirstNativeHeader(ORIGINAL_DESTINATION)).isEqualTo("/user/queue/foo"); } + @Test + @SuppressWarnings("rawtypes") + void handleMessageWithoutSessionIds() { + UserDestinationResolver resolver = mock(); + Message message = createWith(SimpMessageType.MESSAGE, "joe", null, "/user/joe/queue/foo"); + UserDestinationResult result = new UserDestinationResult("/queue/foo-user123", Set.of("/queue/foo-user123"), "/user/queue/foo", "joe"); + given(resolver.resolveDestination(message)).willReturn(result); + + given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); + UserDestinationMessageHandler handler = new UserDestinationMessageHandler(new StubMessageChannel(), this.brokerChannel, resolver); + handler.handleMessage(message); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); + Mockito.verify(this.brokerChannel).send(captor.capture()); + + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(captor.getValue()); + assertThat(accessor.getDestination()).isEqualTo("/queue/foo-user123"); + assertThat(accessor.getFirstNativeHeader(ORIGINAL_DESTINATION)).isEqualTo("/user/queue/foo"); + } + @Test @SuppressWarnings("rawtypes") void handleMessageWithoutActiveSession() { diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java index c480152c054..45aeecba96b 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.util.Assert; import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.async.DeferredResult.DeferredResultHandler; +import org.springframework.web.util.DisconnectedClientHelper; /** * The central class for managing asynchronous request processing, mainly intended @@ -342,6 +343,10 @@ public final class WebAsyncManager { if (logger.isDebugEnabled()) { logger.debug("Servlet container error notification for " + formatUri(this.asyncWebRequest) + ": " + ex); } + if (DisconnectedClientHelper.isClientDisconnectedException(ex)) { + ex = new AsyncRequestNotUsableException( + "Servlet container error notification for disconnected client", ex); + } Object result = interceptorChain.triggerAfterError(this.asyncWebRequest, callable, ex); result = (result != CallableProcessingInterceptor.RESULT_NONE ? result : ex); setConcurrentResultAndDispatch(result); @@ -434,6 +439,10 @@ public final class WebAsyncManager { if (logger.isDebugEnabled()) { logger.debug("Servlet container error notification for " + formatUri(this.asyncWebRequest)); } + if (DisconnectedClientHelper.isClientDisconnectedException(ex)) { + ex = new AsyncRequestNotUsableException( + "Servlet container error notification for disconnected client", ex); + } try { interceptorChain.triggerAfterError(this.asyncWebRequest, deferredResult, ex); synchronized (WebAsyncManager.this) { diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerErrorTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerErrorTests.java index 542638dcbf2..d11d6714c0a 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerErrorTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerErrorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.web.context.request.async; +import java.io.IOException; import java.util.concurrent.Callable; import jakarta.servlet.AsyncEvent; @@ -152,6 +153,22 @@ class WebAsyncManagerErrorTests { verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, callable); } + @Test // gh-34363 + void startCallableProcessingDisconnectedClient() throws Exception { + StubCallable callable = new StubCallable(); + this.asyncManager.startCallableProcessing(callable); + + IOException ex = new IOException("broken pipe"); + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), ex); + this.asyncWebRequest.onError(event); + + MockAsyncContext asyncContext = (MockAsyncContext) this.servletRequest.getAsyncContext(); + assertThat(this.asyncManager.hasConcurrentResult()).isTrue(); + assertThat(this.asyncManager.getConcurrentResult()) + .as("Disconnected client error not wrapped AsyncRequestNotUsableException") + .isOfAnyClassIn(AsyncRequestNotUsableException.class); + } + @Test void startDeferredResultProcessingErrorAndComplete() throws Exception { @@ -259,6 +276,21 @@ class WebAsyncManagerErrorTests { assertThat(((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()).isEqualTo("/test"); } + @Test // gh-34363 + void startDeferredResultProcessingDisconnectedClient() throws Exception { + DeferredResult deferredResult = new DeferredResult<>(); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + IOException ex = new IOException("broken pipe"); + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), ex); + this.asyncWebRequest.onError(event); + + assertThat(this.asyncManager.hasConcurrentResult()).isTrue(); + assertThat(deferredResult.getResult()) + .as("Disconnected client error not wrapped AsyncRequestNotUsableException") + .isOfAnyClassIn(AsyncRequestNotUsableException.class); + } + private static final class StubCallable implements Callable { @Override diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java index ce18760cafa..f8fda8bb8b2 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java @@ -198,23 +198,25 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle { HttpMethod method = request.getMethod(); HttpHeaders headers = request.getHeaders(); - if (HttpMethod.GET != method && CONNECT_METHOD != method) { + if (HttpMethod.GET != method && !CONNECT_METHOD.equals(method)) { return Mono.error(new MethodNotAllowedException( request.getMethod(), Set.of(HttpMethod.GET, CONNECT_METHOD))); } - if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { - return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers); - } + if (HttpMethod.GET == method) { + if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { + return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers); + } - List connectionValue = headers.getConnection(); - if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { - return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers); - } + List connectionValue = headers.getConnection(); + if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { + return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers); + } - String key = headers.getFirst(SEC_WEBSOCKET_KEY); - if (key == null) { - return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header"); + String key = headers.getFirst(SEC_WEBSOCKET_KEY); + if (key == null) { + return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header"); + } } String protocol = selectProtocol(headers, handler); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java index 06ebf19469d..a4876a5272a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -149,7 +149,7 @@ public class WebSocketHttpHeaders extends HttpHeaders { } /** - * Returns the value of the {@code Sec-WebSocket-Key} header. + * Returns the value of the {@code Sec-WebSocket-Protocol} header. * @return the value of the header */ public List getSecWebSocketProtocol() { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java index 237be64cf57..0b2193190e4 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java @@ -175,7 +175,7 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Life } try { HttpMethod httpMethod = request.getMethod(); - if (HttpMethod.GET != httpMethod && CONNECT_METHOD != httpMethod) { + if (HttpMethod.GET != httpMethod && !CONNECT_METHOD.equals(httpMethod)) { response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD)); if (logger.isErrorEnabled()) { @@ -183,13 +183,24 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Life } return false; } - if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { - handleInvalidUpgradeHeader(request, response); - return false; - } - if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) { - handleInvalidConnectHeader(request, response); - return false; + if (HttpMethod.GET == httpMethod) { + if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { + handleInvalidUpgradeHeader(request, response); + return false; + } + List connectionValue = headers.getConnection(); + if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { + handleInvalidConnectHeader(request, response); + return false; + } + String key = headers.getSecWebSocketKey(); + if (key == null) { + if (logger.isErrorEnabled()) { + logger.error("Missing \"Sec-WebSocket-Key\" header"); + } + response.setStatusCode(HttpStatus.BAD_REQUEST); + return false; + } } if (!isWebSocketVersionSupported(headers)) { handleWebSocketVersionNotSupported(request, response); @@ -199,14 +210,6 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Life response.setStatusCode(HttpStatus.FORBIDDEN); return false; } - String wsKey = headers.getSecWebSocketKey(); - if (wsKey == null) { - if (logger.isErrorEnabled()) { - logger.error("Missing \"Sec-WebSocket-Key\" header"); - } - response.setStatusCode(HttpStatus.BAD_REQUEST); - return false; - } } catch (IOException ex) { throw new HandshakeFailureException(