diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java index fe9246b4b2d..957505cb26b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java @@ -24,6 +24,7 @@ import org.springframework.core.annotation.AnnotationUtils; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.core.MessagePostProcessor; +import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -158,7 +159,8 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH return value; } } - return new String[] { defaultPrefix + inputHeaders.getDestination() }; + return new String[] { defaultPrefix + + inputHeaders.getHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER) }; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java index 1994c7b4753..3bd8cfa1421 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java @@ -357,7 +357,6 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan String matchedPattern = mapping.getDestinationConditions().getPatterns().iterator().next(); Map vars = getPathMatcher().extractUriTemplateVariables(matchedPattern, lookupDestination); - headers.setDestination(lookupDestination); headers.setHeader(DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER, vars); message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java index 84f71e14cd3..1e2dd57fceb 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java @@ -32,6 +32,7 @@ import org.mockito.MockitoAnnotations; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessagingTemplate; @@ -116,7 +117,7 @@ public class SendToMethodReturnValueHandlerTests { when(this.messageChannel.send(any(Message.class))).thenReturn(true); - Message inputMessage = createInputMessage("sess1", "sub1", "/dest", null); + Message inputMessage = createInputMessage("sess1", "sub1", "/app", "/dest", null); this.handler.handleReturnValue(payloadContent, this.noAnnotationsReturnType, inputMessage); verify(this.messageChannel, times(1)).send(this.messageCaptor.capture()); @@ -134,7 +135,7 @@ public class SendToMethodReturnValueHandlerTests { when(this.messageChannel.send(any(Message.class))).thenReturn(true); String sessionId = "sess1"; - Message inputMessage = createInputMessage(sessionId, "sub1", null, null); + Message inputMessage = createInputMessage(sessionId, "sub1", null, null, null); this.handler.handleReturnValue(payloadContent, this.sendToReturnType, inputMessage); verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); @@ -158,7 +159,7 @@ public class SendToMethodReturnValueHandlerTests { when(this.messageChannel.send(any(Message.class))).thenReturn(true); String sessionId = "sess1"; - Message inputMessage = createInputMessage(sessionId, "sub1", "/dest", null); + Message inputMessage = createInputMessage(sessionId, "sub1", "/app", "/dest", null); this.handler.handleReturnValue(payloadContent, this.sendToDefaultDestReturnType, inputMessage); verify(this.messageChannel, times(1)).send(this.messageCaptor.capture()); @@ -177,7 +178,7 @@ public class SendToMethodReturnValueHandlerTests { String sessionId = "sess1"; TestUser user = new TestUser(); - Message inputMessage = createInputMessage(sessionId, "sub1", null, user); + Message inputMessage = createInputMessage(sessionId, "sub1", null, null, user); this.handler.handleReturnValue(payloadContent, this.sendToUserReturnType, inputMessage); verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); @@ -202,7 +203,7 @@ public class SendToMethodReturnValueHandlerTests { String sessionId = "sess1"; TestUser user = new UniqueUser(); - Message inputMessage = createInputMessage(sessionId, "sub1", null, user); + Message inputMessage = createInputMessage(sessionId, "sub1", null, null, user); this.handler.handleReturnValue(payloadContent, this.sendToUserReturnType, inputMessage); verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); @@ -221,7 +222,7 @@ public class SendToMethodReturnValueHandlerTests { String sessionId = "sess1"; TestUser user = new TestUser(); - Message inputMessage = createInputMessage(sessionId, "sub1", "/dest", user); + Message inputMessage = createInputMessage(sessionId, "sub1", "/app", "/dest", user); this.handler.handleReturnValue(payloadContent, this.sendToUserDefaultDestReturnType, inputMessage); verify(this.messageChannel, times(1)).send(this.messageCaptor.capture()); @@ -234,12 +235,14 @@ public class SendToMethodReturnValueHandlerTests { } - private Message createInputMessage(String sessId, String subsId, String destination, Principal principal) { + private Message createInputMessage(String sessId, String subsId, String destinationPrefix, + String destination, Principal principal) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); headers.setSessionId(sessId); headers.setSubscriptionId(subsId); - if (destination != null) { - headers.setDestination(destination); + if (destination != null && destinationPrefix != null) { + headers.setDestination(destinationPrefix + destination); + headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination); } if (principal != null) { headers.setUser(principal); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java index 756a86eda9d..10113ca48c9 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java @@ -36,6 +36,7 @@ import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.simp.annotation.SubscribeMapping; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.support.AbstractSubscribableChannel; @@ -134,6 +135,28 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration } } + // SPR-11648 + + @Test + public void sendSubscribeToControllerAndReceiveReply() throws Exception { + + TextMessage message = create(StompCommand.SUBSCRIBE).headers( + "id:subs1", "destination:/app/number").build(); + + TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, message); + WebSocketSession session = doHandshake(clientHandler, "/ws").get(); + + try { + assertTrue(clientHandler.latch.await(2, TimeUnit.SECONDS)); + String payload = clientHandler.actual.get(0).getPayload(); + assertTrue("Expected STOMP destination=/app/number, got " + payload, payload.contains("destination:/app/number")); + assertTrue("Expected STOMP Payload=42, got " + payload, payload.contains("42")); + } + finally { + session.close(); + } + } + @IntegrationTestController static class SimpleController { @@ -164,6 +187,11 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration public int handle(int i) { return i + 1; } + + @SubscribeMapping("/number") + public int number() { + return 42; + } }