Browse Source

Merge branch '6.2.x'

pull/35649/head
rstoyanchev 2 months ago
parent
commit
f15c12a190
  1. 96
      spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
  2. 9
      spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java
  3. 5
      spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java

96
spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

@ -25,6 +25,7 @@ import java.util.Map; @@ -25,6 +25,7 @@ import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -106,9 +107,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -106,9 +107,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
private @Nullable MessageHeaderInitializer headerInitializer;
private @Nullable Map<String, MessageChannel> orderedHandlingMessageChannels;
private final Map<String, SessionInfo> sessions = new ConcurrentHashMap<>();
private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<>();
private boolean preserveReceiveOrder;
private @Nullable Boolean immutableMessageInterceptorPresent;
@ -201,7 +202,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -201,7 +202,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
* @since 6.1
*/
public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
this.orderedHandlingMessageChannels = (preserveReceiveOrder ? new ConcurrentHashMap<>() : null);
this.preserveReceiveOrder = preserveReceiveOrder;
}
/**
@ -210,7 +211,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -210,7 +211,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
* @since 6.1
*/
public boolean isPreserveReceiveOrder() {
return (this.orderedHandlingMessageChannels != null);
return this.preserveReceiveOrder;
}
@Override
@ -245,7 +246,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -245,7 +246,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
*/
@Override
public void handleMessageFromClient(WebSocketSession session,
WebSocketMessage<?> webSocketMessage, MessageChannel targetChannel) {
WebSocketMessage<?> webSocketMessage, MessageChannel channel) {
List<Message<byte[]>> messages;
try {
@ -288,35 +289,36 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -288,35 +289,36 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
return;
}
MessageChannel channelToUse = targetChannel;
if (this.orderedHandlingMessageChannels != null) {
channelToUse = this.orderedHandlingMessageChannels.computeIfAbsent(
session.getId(), id -> new OrderedMessageChannelDecorator(targetChannel, logger));
}
SessionInfo info = this.sessions.get(session.getId());
MessageChannel channelToUse = (info != null ? info.getMessageChannelToUse() : null);
for (Message<byte[]> message : messages) {
StompHeaderAccessor headerAccessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
Assert.state(headerAccessor != null, "No StompHeaderAccessor");
StompCommand command = headerAccessor.getCommand();
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
boolean isConnect = (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command));
String sessionId = session.getId();
boolean sent = false;
try {
if (isConnect) {
channelToUse = (this.preserveReceiveOrder ? new OrderedMessageChannelDecorator(channel, logger) : channel);
info = new SessionInfo(channelToUse, session.getPrincipal());
SessionInfo prevInfo = this.sessions.putIfAbsent(sessionId, info);
Assert.state(prevInfo == null, "Session already exists");
headerAccessor.setUserChangeCallback(info);
}
else {
Assert.state(channelToUse != null, "Unknown session: " + sessionId);
}
headerAccessor.setSessionId(session.getId());
headerAccessor.setSessionId(sessionId);
headerAccessor.setSessionAttributes(session.getAttributes());
headerAccessor.setUser(getUser(session));
if (isConnect) {
headerAccessor.setUserChangeCallback(user -> {
if (user != null && user != session.getPrincipal()) {
this.stompAuthentications.put(session.getId(), user);
}
});
}
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
if (!detectImmutableMessageInterceptor(targetChannel)) {
if (!detectImmutableMessageInterceptor(channel)) {
headerAccessor.setImmutable();
}
@ -356,23 +358,28 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -356,23 +358,28 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
catch (Throwable ex) {
if (logger.isDebugEnabled()) {
logger.debug("Failed to send message to MessageChannel in session " + session.getId(), ex);
logger.debug("Failed to send message to MessageChannel in session " + sessionId, ex);
}
else if (logger.isErrorEnabled()) {
// Skip for unsent CONNECT or SUBSCRIBE (likely authentication/authorization issues)
if (sent || !(isConnect || StompCommand.SUBSCRIBE.equals(command))) {
logger.error("Failed to send message to MessageChannel in session " +
session.getId() + ":" + ex.getMessage());
sessionId + ":" + ex.getMessage());
}
}
handleError(session, ex, message);
}
if (!sent && isConnect) {
this.sessions.remove(sessionId);
break;
}
}
}
private @Nullable Principal getUser(WebSocketSession session) {
Principal user = this.stompAuthentications.get(session.getId());
return (user != null ? user : session.getPrincipal());
SessionInfo info = this.sessions.get(session.getId());
return (info != null ? info.getUser() : session.getPrincipal());
}
private void handleError(WebSocketSession session, Throwable ex, @Nullable Message<byte[]> clientMessage) {
@ -674,10 +681,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -674,10 +681,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
outputChannel.send(message);
}
finally {
if (this.orderedHandlingMessageChannels != null) {
this.orderedHandlingMessageChannels.remove(session.getId());
}
this.stompAuthentications.remove(session.getId());
this.sessions.remove(session.getId());
SimpAttributesContextHolder.resetAttributes();
simpAttributes.sessionCompleted();
}
@ -707,6 +711,36 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -707,6 +711,36 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
private static class SessionInfo implements Consumer<Principal> {
private final MessageChannel channel;
private final @Nullable Principal webSocketUser;
private volatile @Nullable Principal stompUser;
SessionInfo(MessageChannel channel, @Nullable Principal user) {
this.channel = channel;
this.webSocketUser = user;
}
public MessageChannel getMessageChannelToUse() {
return this.channel;
}
public @Nullable Principal getUser() {
return (this.stompUser != null ? this.stompUser : this.webSocketUser);
}
@Override
public void accept(@Nullable Principal stompUser) {
if (stompUser != null && stompUser != this.webSocketUser) {
this.stompUser = stompUser;
}
}
}
/**
* Contract for access to session counters.
* @since 5.2

9
spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java

@ -101,6 +101,9 @@ class WebSocketMessageBrokerConfigurationSupportTests { @@ -101,6 +101,9 @@ class WebSocketMessageBrokerConfigurationSupportTests {
session.setOpen(true);
webSocketHandler.afterConnectionEstablished(session);
webSocketHandler.handleMessage(session,
StompTextMessageBuilder.create(StompCommand.CONNECT).headers("destination:/foo").build());
webSocketHandler.handleMessage(session,
StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build());
@ -108,6 +111,12 @@ class WebSocketMessageBrokerConfigurationSupportTests { @@ -108,6 +111,12 @@ class WebSocketMessageBrokerConfigurationSupportTests {
StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
assertThat(accessor).isNotNull();
assertThat(accessor.isMutable()).isFalse();
assertThat(accessor.getMessageType()).isEqualTo(SimpMessageType.CONNECT);
message = channel.messages.get(1);
accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
assertThat(accessor).isNotNull();
assertThat(accessor.isMutable()).isFalse();
assertThat(accessor.getMessageType()).isEqualTo(SimpMessageType.MESSAGE);
assertThat(accessor.getDestination()).isEqualTo("/foo");
}

5
spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java

@ -89,9 +89,10 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -89,9 +89,10 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
super.setup(server, webSocketClient, testInfo);
TextMessage message = create(StompCommand.SEND).headers("destination:/app/simple").build();
TextMessage m1 = create(StompCommand.CONNECT).headers("accept-version:1.1").build();
TextMessage m2 = create(StompCommand.SEND).headers("destination:/app/simple").build();
try (WebSocketSession session = execute(new TestClientWebSocketHandler(0, message), "/ws").get()) {
try (WebSocketSession session = execute(new TestClientWebSocketHandler(0, m1, m2), "/ws").get()) {
assertThat(session).isNotNull();
SimpleController controller = this.wac.getBean(SimpleController.class);
assertThat(controller.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();

Loading…
Cancel
Save