diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java index bda6e87a4b8..94ace30f42a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java @@ -156,7 +156,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec } SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); if (SimpMessageType.MESSAGE.equals(headerAccessor.getMessageType())) { - headerAccessor.setHeader(SUBSCRIBE_DESTINATION, result.getSubscribeDestination()); + headerAccessor.setNativeHeader(SUBSCRIBE_DESTINATION, result.getSubscribeDestination()); message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headerAccessor).build(); } for (String targetDestination : destinations) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java index 717388af0d1..bb0751e51dc 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java @@ -29,6 +29,7 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.support.MessageBuilder; +import sun.security.provider.SHA; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; @@ -93,10 +94,10 @@ public class UserDestinationMessageHandlerTests { ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); Mockito.verify(this.brokerChannel).send(captor.capture()); - assertEquals("/queue/foo-user123", - captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER)); - assertEquals("/user/queue/foo", - captor.getValue().getHeaders().get(UserDestinationMessageHandler.SUBSCRIBE_DESTINATION)); + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(captor.getValue()); + + assertEquals("/queue/foo-user123", accessor.getDestination()); + assertEquals("/user/queue/foo", accessor.getFirstNativeHeader(UserDestinationMessageHandler.SUBSCRIBE_DESTINATION)); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 95c23da0a58..d6c56f8c5d1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -251,9 +251,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE logger.error("Ignoring message, no subscriptionId header: " + message); return; } - String header = UserDestinationMessageHandler.SUBSCRIBE_DESTINATION; - if (message.getHeaders().containsKey(header)) { - headers.setDestination((String) message.getHeaders().get(header)); + String name = UserDestinationMessageHandler.SUBSCRIBE_DESTINATION; + String origDestination = headers.getFirstNativeHeader(name); + if (origDestination != null) { + headers.setDestination(origDestination); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 426a0599ab4..983ed65f374 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -209,6 +209,7 @@ public class StompSubProtocolHandlerTests { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); TextMessage textMessage = new TextMessage(new StompEncoder().encode(message)); + this.protocolHandler.afterSessionStarted(this.session, this.channel); this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); verify(this.channel).send(this.messageCaptor.capture()); @@ -240,8 +241,9 @@ public class StompSubProtocolHandlerTests { headers.setMessageId("mess0"); headers.setSubscriptionId("sub0"); headers.setDestination("/queue/foo-user123"); - headers.setHeader(UserDestinationMessageHandler.SUBSCRIBE_DESTINATION, "/user/queue/foo"); + headers.setNativeHeader(UserDestinationMessageHandler.SUBSCRIBE_DESTINATION, "/user/queue/foo"); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + this.protocolHandler.afterSessionStarted(this.session, this.channel); this.protocolHandler.handleMessageToClient(this.session, message); assertEquals(1, this.session.getSentMessages().size()); @@ -278,8 +280,9 @@ public class StompSubProtocolHandlerTests { @Test public void handleMessageFromClientInvalidStompCommand() { - TextMessage textMessage = new TextMessage("FOO"); + TextMessage textMessage = new TextMessage("FOO\n\n\0"); + this.protocolHandler.afterSessionStarted(this.session, this.channel); this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); verifyZeroInteractions(this.channel);