Browse Source

Refactor state management in StompSubProtocolHandler

Closes gh-35591
pull/35672/head
rstoyanchev 2 months ago
parent
commit
c88bfc54c9
  1. 100
      spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
  2. 9
      spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java
  3. 5
      spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java

100
spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

@ -25,6 +25,7 @@ import java.util.Map; @@ -25,6 +25,7 @@ import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -108,10 +109,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -108,10 +109,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
@Nullable
private MessageHeaderInitializer headerInitializer;
@Nullable
private Map<String, MessageChannel> orderedHandlingMessageChannels;
private final Map<String, SessionInfo> sessions = new ConcurrentHashMap<>();
private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<>();
private boolean preserveReceiveOrder;
@Nullable
private Boolean immutableMessageInterceptorPresent;
@ -208,7 +208,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -208,7 +208,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
* @since 6.1
*/
public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
this.orderedHandlingMessageChannels = (preserveReceiveOrder ? new ConcurrentHashMap<>() : null);
this.preserveReceiveOrder = preserveReceiveOrder;
}
/**
@ -217,7 +217,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -217,7 +217,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
* @since 6.1
*/
public boolean isPreserveReceiveOrder() {
return (this.orderedHandlingMessageChannels != null);
return this.preserveReceiveOrder;
}
@Override
@ -252,7 +252,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -252,7 +252,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
*/
@Override
public void handleMessageFromClient(WebSocketSession session,
WebSocketMessage<?> webSocketMessage, MessageChannel targetChannel) {
WebSocketMessage<?> webSocketMessage, MessageChannel channel) {
List<Message<byte[]>> messages;
try {
@ -295,35 +295,36 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -295,35 +295,36 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
return;
}
MessageChannel channelToUse = targetChannel;
if (this.orderedHandlingMessageChannels != null) {
channelToUse = this.orderedHandlingMessageChannels.computeIfAbsent(
session.getId(), id -> new OrderedMessageChannelDecorator(targetChannel, logger));
}
SessionInfo info = this.sessions.get(session.getId());
MessageChannel channelToUse = (info != null ? info.getMessageChannelToUse() : null);
for (Message<byte[]> message : messages) {
StompHeaderAccessor headerAccessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
Assert.state(headerAccessor != null, "No StompHeaderAccessor");
StompCommand command = headerAccessor.getCommand();
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
boolean isConnect = (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command));
String sessionId = session.getId();
boolean sent = false;
try {
if (isConnect) {
channelToUse = (this.preserveReceiveOrder ? new OrderedMessageChannelDecorator(channel, logger) : channel);
info = new SessionInfo(channelToUse, session.getPrincipal());
SessionInfo prevInfo = this.sessions.putIfAbsent(sessionId, info);
Assert.state(prevInfo == null, "Session already exists");
headerAccessor.setUserChangeCallback(info);
}
else {
Assert.state(channelToUse != null, "Unknown session: " + sessionId);
}
headerAccessor.setSessionId(session.getId());
headerAccessor.setSessionId(sessionId);
headerAccessor.setSessionAttributes(session.getAttributes());
headerAccessor.setUser(getUser(session));
if (isConnect) {
headerAccessor.setUserChangeCallback(user -> {
if (user != null && user != session.getPrincipal()) {
this.stompAuthentications.put(session.getId(), user);
}
});
}
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
if (!detectImmutableMessageInterceptor(targetChannel)) {
if (!detectImmutableMessageInterceptor(channel)) {
headerAccessor.setImmutable();
}
@ -363,24 +364,29 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -363,24 +364,29 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
catch (Throwable ex) {
if (logger.isDebugEnabled()) {
logger.debug("Failed to send message to MessageChannel in session " + session.getId(), ex);
logger.debug("Failed to send message to MessageChannel in session " + sessionId, ex);
}
else if (logger.isErrorEnabled()) {
// Skip for unsent CONNECT or SUBSCRIBE (likely authentication/authorization issues)
if (sent || !(isConnect || StompCommand.SUBSCRIBE.equals(command))) {
logger.error("Failed to send message to MessageChannel in session " +
session.getId() + ":" + ex.getMessage());
sessionId + ":" + ex.getMessage());
}
}
handleError(session, ex, message);
}
if (!sent && isConnect) {
this.sessions.remove(sessionId);
break;
}
}
}
@Nullable
private Principal getUser(WebSocketSession session) {
Principal user = this.stompAuthentications.get(session.getId());
return (user != null ? user : session.getPrincipal());
SessionInfo info = this.sessions.get(session.getId());
return (info != null ? info.getUser() : session.getPrincipal());
}
private void handleError(WebSocketSession session, Throwable ex, @Nullable Message<byte[]> clientMessage) {
@ -685,10 +691,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -685,10 +691,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
outputChannel.send(message);
}
finally {
if (this.orderedHandlingMessageChannels != null) {
this.orderedHandlingMessageChannels.remove(session.getId());
}
this.stompAuthentications.remove(session.getId());
this.sessions.remove(session.getId());
SimpAttributesContextHolder.resetAttributes();
simpAttributes.sessionCompleted();
}
@ -718,6 +721,39 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -718,6 +721,39 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
private static class SessionInfo implements Consumer<Principal> {
private final MessageChannel channel;
@Nullable
private final Principal webSocketUser;
@Nullable
private volatile Principal stompUser;
SessionInfo(MessageChannel channel, @Nullable Principal user) {
this.channel = channel;
this.webSocketUser = user;
}
public MessageChannel getMessageChannelToUse() {
return this.channel;
}
@Nullable
public Principal getUser() {
return (this.stompUser != null ? this.stompUser : this.webSocketUser);
}
@Override
public void accept(@Nullable Principal stompUser) {
if (stompUser != null && stompUser != this.webSocketUser) {
this.stompUser = stompUser;
}
}
}
/**
* Contract for access to session counters.
* @since 5.2

9
spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java

@ -101,6 +101,9 @@ class WebSocketMessageBrokerConfigurationSupportTests { @@ -101,6 +101,9 @@ class WebSocketMessageBrokerConfigurationSupportTests {
session.setOpen(true);
webSocketHandler.afterConnectionEstablished(session);
webSocketHandler.handleMessage(session,
StompTextMessageBuilder.create(StompCommand.CONNECT).headers("destination:/foo").build());
webSocketHandler.handleMessage(session,
StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build());
@ -108,6 +111,12 @@ class WebSocketMessageBrokerConfigurationSupportTests { @@ -108,6 +111,12 @@ class WebSocketMessageBrokerConfigurationSupportTests {
StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
assertThat(accessor).isNotNull();
assertThat(accessor.isMutable()).isFalse();
assertThat(accessor.getMessageType()).isEqualTo(SimpMessageType.CONNECT);
message = channel.messages.get(1);
accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
assertThat(accessor).isNotNull();
assertThat(accessor.isMutable()).isFalse();
assertThat(accessor.getMessageType()).isEqualTo(SimpMessageType.MESSAGE);
assertThat(accessor.getDestination()).isEqualTo("/foo");
}

5
spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java

@ -89,9 +89,10 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -89,9 +89,10 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
super.setup(server, webSocketClient, testInfo);
TextMessage message = create(StompCommand.SEND).headers("destination:/app/simple").build();
TextMessage m1 = create(StompCommand.CONNECT).headers("accept-version:1.1").build();
TextMessage m2 = create(StompCommand.SEND).headers("destination:/app/simple").build();
try (WebSocketSession session = execute(new TestClientWebSocketHandler(0, message), "/ws").get()) {
try (WebSocketSession session = execute(new TestClientWebSocketHandler(0, m1, m2), "/ws").get()) {
assertThat(session).isNotNull();
SimpleController controller = this.wac.getBean(SimpleController.class);
assertThat(controller.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();

Loading…
Cancel
Save