From 443fb8e4eed3c0db4185633959266fdf0ff2cbe7 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 25 Apr 2014 12:12:56 -0400 Subject: [PATCH] Fix issue with subscribe destination The original fix for SPR-11423: https://github.com/spring-projects/spring-framework/commit/32e5f57e647022d9ea82c03670936bf31f8467de was insufficient when using an external broker since the original destination header has to be in the "native headers" map (i.e. with STOMP headers) in order to be included in messages broadcast by the broker. --- .../simp/user/UserDestinationMessageHandler.java | 2 +- .../simp/user/UserDestinationMessageHandlerTests.java | 9 +++++---- .../web/socket/messaging/StompSubProtocolHandler.java | 7 ++++--- .../socket/messaging/StompSubProtocolHandlerTests.java | 7 +++++-- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java index bda6e87a4b8..94ace30f42a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java @@ -156,7 +156,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec } SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); if (SimpMessageType.MESSAGE.equals(headerAccessor.getMessageType())) { - headerAccessor.setHeader(SUBSCRIBE_DESTINATION, result.getSubscribeDestination()); + headerAccessor.setNativeHeader(SUBSCRIBE_DESTINATION, result.getSubscribeDestination()); message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headerAccessor).build(); } for (String targetDestination : destinations) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java index 717388af0d1..bb0751e51dc 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java @@ -29,6 +29,7 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.support.MessageBuilder; +import sun.security.provider.SHA; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; @@ -93,10 +94,10 @@ public class UserDestinationMessageHandlerTests { ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); Mockito.verify(this.brokerChannel).send(captor.capture()); - assertEquals("/queue/foo-user123", - captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER)); - assertEquals("/user/queue/foo", - captor.getValue().getHeaders().get(UserDestinationMessageHandler.SUBSCRIBE_DESTINATION)); + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(captor.getValue()); + + assertEquals("/queue/foo-user123", accessor.getDestination()); + assertEquals("/user/queue/foo", accessor.getFirstNativeHeader(UserDestinationMessageHandler.SUBSCRIBE_DESTINATION)); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 95c23da0a58..d6c56f8c5d1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -251,9 +251,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE logger.error("Ignoring message, no subscriptionId header: " + message); return; } - String header = UserDestinationMessageHandler.SUBSCRIBE_DESTINATION; - if (message.getHeaders().containsKey(header)) { - headers.setDestination((String) message.getHeaders().get(header)); + String name = UserDestinationMessageHandler.SUBSCRIBE_DESTINATION; + String origDestination = headers.getFirstNativeHeader(name); + if (origDestination != null) { + headers.setDestination(origDestination); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 426a0599ab4..983ed65f374 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -209,6 +209,7 @@ public class StompSubProtocolHandlerTests { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); TextMessage textMessage = new TextMessage(new StompEncoder().encode(message)); + this.protocolHandler.afterSessionStarted(this.session, this.channel); this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); verify(this.channel).send(this.messageCaptor.capture()); @@ -240,8 +241,9 @@ public class StompSubProtocolHandlerTests { headers.setMessageId("mess0"); headers.setSubscriptionId("sub0"); headers.setDestination("/queue/foo-user123"); - headers.setHeader(UserDestinationMessageHandler.SUBSCRIBE_DESTINATION, "/user/queue/foo"); + headers.setNativeHeader(UserDestinationMessageHandler.SUBSCRIBE_DESTINATION, "/user/queue/foo"); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + this.protocolHandler.afterSessionStarted(this.session, this.channel); this.protocolHandler.handleMessageToClient(this.session, message); assertEquals(1, this.session.getSentMessages().size()); @@ -278,8 +280,9 @@ public class StompSubProtocolHandlerTests { @Test public void handleMessageFromClientInvalidStompCommand() { - TextMessage textMessage = new TextMessage("FOO"); + TextMessage textMessage = new TextMessage("FOO\n\n\0"); + this.protocolHandler.afterSessionStarted(this.session, this.channel); this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); verifyZeroInteractions(this.channel);