diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java index f2e94092681..019aa4e3966 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java @@ -52,6 +52,10 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH private final boolean annotationRequired; + private String defaultDestinationPrefix = "/topic"; + + private String defaultUserDestinationPrefix = "/queue"; + public SendToMethodReturnValueHandler(SimpMessageSendingOperations messagingTemplate, boolean annotationRequired) { Assert.notNull(messagingTemplate, "messagingTemplate is required"); @@ -60,6 +64,45 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH } + /** + * Configure a default prefix to add to message destinations in cases where a method + * is not annotated with {@link SendTo @SendTo} or does not specify any destinations + * through the annotation's value attribute. + *
+ * By default, the prefix is set to "/topic". + */ + public void setDefaultDestinationPrefix(String defaultDestinationPrefix) { + this.defaultDestinationPrefix = defaultDestinationPrefix; + } + + /** + * Return the configured default destination prefix. + * @see #setDefaultDestinationPrefix(String) + */ + public String getDefaultDestinationPrefix() { + return this.defaultDestinationPrefix; + } + + /** + * Configure a default prefix to add to message destinations in cases where a + * method is annotated with {@link SendToUser @SendToUser} but does not specify + * any destinations through the annotation's value attribute. + *
+ * By default, the prefix is set to "/queue". + */ + public void setDefaultUserDestinationPrefix(String prefix) { + this.defaultUserDestinationPrefix = prefix; + } + + /** + * Return the configured default user destination prefix. + * @see #setDefaultUserDestinationPrefix(String) + */ + public String getDefaultUserDestinationPrefix() { + return this.defaultUserDestinationPrefix; + } + + @Override public boolean supportsReturnType(MethodParameter returnType) { if ((returnType.getMethodAnnotation(SendTo.class) != null) || @@ -88,26 +131,31 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH throw new MissingSessionUserException(inputMessage); } String user = inputHeaders.getUser().getName(); - for (String destination : getDestinations(sendToUser, inputHeaders.getDestination())) { + String[] destinations = getTargetDestinations(sendToUser, inputHeaders, this.defaultUserDestinationPrefix); + for (String destination : destinations) { this.messagingTemplate.convertAndSendToUser(user, destination, returnValue, postProcessor); } return; } - - SendTo sendTo = returnType.getMethodAnnotation(SendTo.class); - if (sendTo != null) { - for (String destination : getDestinations(sendTo, inputHeaders.getDestination())) { + else { + SendTo sendTo = returnType.getMethodAnnotation(SendTo.class); + String[] destinations = getTargetDestinations(sendTo, inputHeaders, this.defaultDestinationPrefix); + for (String destination : getTargetDestinations(sendTo, inputHeaders, this.defaultDestinationPrefix)) { this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor); } - return; } - - this.messagingTemplate.convertAndSend(inputHeaders.getDestination(), returnValue, postProcessor); } - private String[] getDestinations(Annotation annot, String inputDestination) { - String[] destinations = (String[]) AnnotationUtils.getValue(annot); - return ObjectUtils.isEmpty(destinations) ? new String[] { inputDestination } : destinations; + protected String[] getTargetDestinations(Annotation annot, SimpMessageHeaderAccessor inputHeaders, + String defaultPrefix) { + + if (annot != null) { + String[] value = (String[]) AnnotationUtils.getValue(annot); + if (!ObjectUtils.isEmpty(value)) { + return value; + } + } + return new String[] { defaultPrefix + inputHeaders.getDestination() }; } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java index 089f1d9f20a..31c7df1b8a3 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java @@ -30,7 +30,6 @@ import org.mockito.MockitoAnnotations; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessagingTemplate; @@ -63,11 +62,11 @@ public class SendToMethodReturnValueHandlerTests { @Mock private MessageConverter messageConverter; + private MethodParameter noAnnotationsReturnType; private MethodParameter sendToReturnType; - + private MethodParameter sendToDefaultDestReturnType; private MethodParameter sendToUserReturnType; - - private MethodParameter missingSendToReturnType; + private MethodParameter sendToUserDefaultDestReturnType; @SuppressWarnings("unchecked") @@ -85,14 +84,20 @@ public class SendToMethodReturnValueHandlerTests { this.handler = new SendToMethodReturnValueHandler(messagingTemplate, true); this.handlerAnnotationNotRequired = new SendToMethodReturnValueHandler(messagingTemplate, false); - Method method = this.getClass().getDeclaredMethod("handleAndSendTo"); + Method method = this.getClass().getDeclaredMethod("handleNoAnnotations"); + this.noAnnotationsReturnType = new MethodParameter(method, -1); + + method = this.getClass().getDeclaredMethod("handleAndSendToDefaultDestination"); + this.sendToDefaultDestReturnType = new MethodParameter(method, -1); + + method = this.getClass().getDeclaredMethod("handleAndSendTo"); this.sendToReturnType = new MethodParameter(method, -1); method = this.getClass().getDeclaredMethod("handleAndSendToUser"); this.sendToUserReturnType = new MethodParameter(method, -1); - method = this.getClass().getDeclaredMethod("handleWithMissingSendTo"); - this.missingSendToReturnType = new MethodParameter(method, -1); + method = this.getClass().getDeclaredMethod("handleAndSendToUserDefaultDestination"); + this.sendToUserDefaultDestReturnType = new MethodParameter(method, -1); } @@ -100,8 +105,25 @@ public class SendToMethodReturnValueHandlerTests { public void supportsReturnType() throws Exception { assertTrue(this.handler.supportsReturnType(this.sendToReturnType)); assertTrue(this.handler.supportsReturnType(this.sendToUserReturnType)); - assertFalse(this.handler.supportsReturnType(this.missingSendToReturnType)); - assertTrue(this.handlerAnnotationNotRequired.supportsReturnType(this.missingSendToReturnType)); + assertFalse(this.handler.supportsReturnType(this.noAnnotationsReturnType)); + assertTrue(this.handlerAnnotationNotRequired.supportsReturnType(this.noAnnotationsReturnType)); + } + + @Test + public void sendToNoAnnotations() throws Exception { + + when(this.messageChannel.send(any(Message.class))).thenReturn(true); + + Message> inputMessage = createInputMessage("sess1", "sub1", "/dest", null); + this.handler.handleReturnValue(payloadContent, this.noAnnotationsReturnType, inputMessage); + + verify(this.messageChannel, times(1)).send(this.messageCaptor.capture()); + + Message> message = this.messageCaptor.getAllValues().get(0); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + assertEquals("sess1", headers.getSessionId()); + assertNull(headers.getSubscriptionId()); + assertEquals("/topic/dest", headers.getDestination()); } @Test @@ -110,27 +132,42 @@ public class SendToMethodReturnValueHandlerTests { when(this.messageChannel.send(any(Message.class))).thenReturn(true); String sessionId = "sess1"; - Message> inputMessage = createInputMessage(sessionId, "sub1", "/dest", null); - + Message> inputMessage = createInputMessage(sessionId, "sub1", null, null); this.handler.handleReturnValue(payloadContent, this.sendToReturnType, inputMessage); verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); Message> message = this.messageCaptor.getAllValues().get(0); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - assertEquals(sessionId, headers.getSessionId()); assertNull(headers.getSubscriptionId()); assertEquals("/dest1", headers.getDestination()); message = this.messageCaptor.getAllValues().get(1); headers = SimpMessageHeaderAccessor.wrap(message); - assertEquals(sessionId, headers.getSessionId()); assertNull(headers.getSubscriptionId()); assertEquals("/dest2", headers.getDestination()); } + @Test + public void sendToDefaultDestinationMethod() throws Exception { + + when(this.messageChannel.send(any(Message.class))).thenReturn(true); + + String sessionId = "sess1"; + Message> inputMessage = createInputMessage(sessionId, "sub1", "/dest", null); + this.handler.handleReturnValue(payloadContent, this.sendToDefaultDestReturnType, inputMessage); + + verify(this.messageChannel, times(1)).send(this.messageCaptor.capture()); + + Message> message = this.messageCaptor.getAllValues().get(0); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + assertEquals(sessionId, headers.getSessionId()); + assertNull(headers.getSubscriptionId()); + assertEquals("/topic/dest", headers.getDestination()); + } + @Test public void sendToUserMethod() throws Exception { @@ -138,34 +175,54 @@ public class SendToMethodReturnValueHandlerTests { String sessionId = "sess1"; TestUser user = new TestUser(); - Message> inputMessage = createInputMessage(sessionId, "sub1", "/dest", user); - + Message> inputMessage = createInputMessage(sessionId, "sub1", null, user); this.handler.handleReturnValue(payloadContent, this.sendToUserReturnType, inputMessage); verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); Message> message = this.messageCaptor.getAllValues().get(0); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - assertEquals(sessionId, headers.getSessionId()); assertNull(headers.getSubscriptionId()); assertEquals("/user/" + user.getName() + "/dest1", headers.getDestination()); message = this.messageCaptor.getAllValues().get(1); headers = SimpMessageHeaderAccessor.wrap(message); - assertEquals(sessionId, headers.getSessionId()); assertNull(headers.getSubscriptionId()); assertEquals("/user/" + user.getName() + "/dest2", headers.getDestination()); } + @Test + public void sendToUserDefaultDestinationMethod() throws Exception { + + when(this.messageChannel.send(any(Message.class))).thenReturn(true); + + String sessionId = "sess1"; + TestUser user = new TestUser(); + Message> inputMessage = createInputMessage(sessionId, "sub1", "/dest", user); + this.handler.handleReturnValue(payloadContent, this.sendToUserDefaultDestReturnType, inputMessage); + + verify(this.messageChannel, times(1)).send(this.messageCaptor.capture()); + + Message> message = this.messageCaptor.getAllValues().get(0); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + assertEquals(sessionId, headers.getSessionId()); + assertNull(headers.getSubscriptionId()); + assertEquals("/user/" + user.getName() + "/queue/dest", headers.getDestination()); + } + - private Message> createInputMessage(String sessId, String subsId, String dest, Principal principal) { + private Message> createInputMessage(String sessId, String subsId, String destination, Principal principal) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); headers.setSessionId(sessId); headers.setSubscriptionId(subsId); - headers.setDestination(dest); - headers.setUser(principal); + if (destination != null) { + headers.setDestination(destination); + } + if (principal != null) { + headers.setUser(principal); + } return MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); } @@ -180,18 +237,25 @@ public class SendToMethodReturnValueHandlerTests { } } - @MessageMapping("/handle") // not needed for the tests but here for completeness - public String handleWithMissingSendTo() { + public String handleNoAnnotations() { + return payloadContent; + } + + @SendTo + public String handleAndSendToDefaultDestination() { return payloadContent; } - @MessageMapping("/handle") // not needed for the tests but here for completeness @SendTo({"/dest1", "/dest2"}) public String handleAndSendTo() { return payloadContent; } - @MessageMapping("/handle") // not needed for the tests but here for completeness + @SendToUser + public String handleAndSendToUserDefaultDestination() { + return payloadContent; + } + @SendToUser({"/dest1", "/dest2"}) public String handleAndSendToUser() { return payloadContent; diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodIntegrationTests.java index 03efe0af97a..3d83450ebfa 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodIntegrationTests.java @@ -102,7 +102,7 @@ public class SimpAnnotationMethodIntegrationTests extends AbstractWebSocketInteg "id:subs1", "destination:/topic/increment").build(); TextMessage message2 = create(StompCommand.SEND).headers( - "destination:/app/topic/increment").body("5").build(); + "destination:/app/increment").body("5").build(); TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, message1, message2); WebSocketSession session = doHandshake(clientHandler, "/ws").get(); @@ -163,7 +163,7 @@ public class SimpAnnotationMethodIntegrationTests extends AbstractWebSocketInteg @IntegrationTestController static class IncrementController { - @MessageMapping(value="/topic/increment") + @MessageMapping(value="/increment") public int handle(int i) { return i + 1; }