From c88bfc54c9256e3c07511aa303be2b136c03e0e2 Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Thu, 9 Oct 2025 14:03:31 +0100 Subject: [PATCH] Refactor state management in StompSubProtocolHandler Closes gh-35591 --- .../messaging/StompSubProtocolHandler.java | 100 ++++++++++++------ ...essageBrokerConfigurationSupportTests.java | 9 ++ .../StompWebSocketIntegrationTests.java | 5 +- 3 files changed, 80 insertions(+), 34 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 4254c4d21cb..3937ce7714b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -108,10 +109,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @Nullable private MessageHeaderInitializer headerInitializer; - @Nullable - private Map orderedHandlingMessageChannels; + private final Map sessions = new ConcurrentHashMap<>(); - private final Map stompAuthentications = new ConcurrentHashMap<>(); + private boolean preserveReceiveOrder; @Nullable private Boolean immutableMessageInterceptorPresent; @@ -208,7 +208,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE * @since 6.1 */ public void setPreserveReceiveOrder(boolean preserveReceiveOrder) { - this.orderedHandlingMessageChannels = (preserveReceiveOrder ? new ConcurrentHashMap<>() : null); + this.preserveReceiveOrder = preserveReceiveOrder; } /** @@ -217,7 +217,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE * @since 6.1 */ public boolean isPreserveReceiveOrder() { - return (this.orderedHandlingMessageChannels != null); + return this.preserveReceiveOrder; } @Override @@ -252,7 +252,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE */ @Override public void handleMessageFromClient(WebSocketSession session, - WebSocketMessage webSocketMessage, MessageChannel targetChannel) { + WebSocketMessage webSocketMessage, MessageChannel channel) { List> messages; try { @@ -295,35 +295,36 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE return; } - MessageChannel channelToUse = targetChannel; - if (this.orderedHandlingMessageChannels != null) { - channelToUse = this.orderedHandlingMessageChannels.computeIfAbsent( - session.getId(), id -> new OrderedMessageChannelDecorator(targetChannel, logger)); - } + SessionInfo info = this.sessions.get(session.getId()); + MessageChannel channelToUse = (info != null ? info.getMessageChannelToUse() : null); for (Message message : messages) { - StompHeaderAccessor headerAccessor = - MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); Assert.state(headerAccessor != null, "No StompHeaderAccessor"); StompCommand command = headerAccessor.getCommand(); - boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command); - + boolean isConnect = (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command)); + String sessionId = session.getId(); boolean sent = false; + try { + if (isConnect) { + channelToUse = (this.preserveReceiveOrder ? new OrderedMessageChannelDecorator(channel, logger) : channel); + info = new SessionInfo(channelToUse, session.getPrincipal()); + SessionInfo prevInfo = this.sessions.putIfAbsent(sessionId, info); + Assert.state(prevInfo == null, "Session already exists"); + headerAccessor.setUserChangeCallback(info); + } + else { + Assert.state(channelToUse != null, "Unknown session: " + sessionId); + } - headerAccessor.setSessionId(session.getId()); + headerAccessor.setSessionId(sessionId); headerAccessor.setSessionAttributes(session.getAttributes()); headerAccessor.setUser(getUser(session)); - if (isConnect) { - headerAccessor.setUserChangeCallback(user -> { - if (user != null && user != session.getPrincipal()) { - this.stompAuthentications.put(session.getId(), user); - } - }); - } headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat()); - if (!detectImmutableMessageInterceptor(targetChannel)) { + + if (!detectImmutableMessageInterceptor(channel)) { headerAccessor.setImmutable(); } @@ -363,24 +364,29 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } catch (Throwable ex) { if (logger.isDebugEnabled()) { - logger.debug("Failed to send message to MessageChannel in session " + session.getId(), ex); + logger.debug("Failed to send message to MessageChannel in session " + sessionId, ex); } else if (logger.isErrorEnabled()) { // Skip for unsent CONNECT or SUBSCRIBE (likely authentication/authorization issues) if (sent || !(isConnect || StompCommand.SUBSCRIBE.equals(command))) { logger.error("Failed to send message to MessageChannel in session " + - session.getId() + ":" + ex.getMessage()); + sessionId + ":" + ex.getMessage()); } } handleError(session, ex, message); } + + if (!sent && isConnect) { + this.sessions.remove(sessionId); + break; + } } } @Nullable private Principal getUser(WebSocketSession session) { - Principal user = this.stompAuthentications.get(session.getId()); - return (user != null ? user : session.getPrincipal()); + SessionInfo info = this.sessions.get(session.getId()); + return (info != null ? info.getUser() : session.getPrincipal()); } private void handleError(WebSocketSession session, Throwable ex, @Nullable Message clientMessage) { @@ -685,10 +691,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE outputChannel.send(message); } finally { - if (this.orderedHandlingMessageChannels != null) { - this.orderedHandlingMessageChannels.remove(session.getId()); - } - this.stompAuthentications.remove(session.getId()); + this.sessions.remove(session.getId()); SimpAttributesContextHolder.resetAttributes(); simpAttributes.sessionCompleted(); } @@ -718,6 +721,39 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } + private static class SessionInfo implements Consumer { + + private final MessageChannel channel; + + @Nullable + private final Principal webSocketUser; + + @Nullable + private volatile Principal stompUser; + + SessionInfo(MessageChannel channel, @Nullable Principal user) { + this.channel = channel; + this.webSocketUser = user; + } + + public MessageChannel getMessageChannelToUse() { + return this.channel; + } + + @Nullable + public Principal getUser() { + return (this.stompUser != null ? this.stompUser : this.webSocketUser); + } + + @Override + public void accept(@Nullable Principal stompUser) { + if (stompUser != null && stompUser != this.webSocketUser) { + this.stompUser = stompUser; + } + } + } + + /** * Contract for access to session counters. * @since 5.2 diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java index a8708700538..2bce05c4382 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java @@ -101,6 +101,9 @@ class WebSocketMessageBrokerConfigurationSupportTests { session.setOpen(true); webSocketHandler.afterConnectionEstablished(session); + webSocketHandler.handleMessage(session, + StompTextMessageBuilder.create(StompCommand.CONNECT).headers("destination:/foo").build()); + webSocketHandler.handleMessage(session, StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build()); @@ -108,6 +111,12 @@ class WebSocketMessageBrokerConfigurationSupportTests { StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); assertThat(accessor).isNotNull(); assertThat(accessor.isMutable()).isFalse(); + assertThat(accessor.getMessageType()).isEqualTo(SimpMessageType.CONNECT); + + message = channel.messages.get(1); + accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + assertThat(accessor).isNotNull(); + assertThat(accessor.isMutable()).isFalse(); assertThat(accessor.getMessageType()).isEqualTo(SimpMessageType.MESSAGE); assertThat(accessor.getDestination()).isEqualTo("/foo"); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java index 8c45f4001c0..dcc8b0f97ed 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java @@ -89,9 +89,10 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { super.setup(server, webSocketClient, testInfo); - TextMessage message = create(StompCommand.SEND).headers("destination:/app/simple").build(); + TextMessage m1 = create(StompCommand.CONNECT).headers("accept-version:1.1").build(); + TextMessage m2 = create(StompCommand.SEND).headers("destination:/app/simple").build(); - try (WebSocketSession session = execute(new TestClientWebSocketHandler(0, message), "/ws").get()) { + try (WebSocketSession session = execute(new TestClientWebSocketHandler(0, m1, m2), "/ws").get()) { assertThat(session).isNotNull(); SimpleController controller = this.wac.getBean(SimpleController.class); assertThat(controller.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();