diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java
index 2058a88bbc5..9be712ba47c 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java
@@ -192,8 +192,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
if (SimpMessageType.MESSAGE.equals(messageType)) {
sessionId = (sessionId == null) ? SystemStompRelaySession.ID : sessionId;
headers.setSessionId(sessionId);
- command = (command == null) ? StompCommand.SEND : command;
- headers.setCommandIfNotSet(command);
+ headers.updateStompCommandAsClientMessage();
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
}
diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java
index f3771059446..cdb082ef695 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java
@@ -26,6 +26,7 @@ import java.util.concurrent.atomic.AtomicLong;
import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
+import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
@@ -218,9 +219,29 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
return toNativeHeaderMap();
}
- public void setCommandIfNotSet(StompCommand command) {
+ public void updateStompCommandAsClientMessage() {
+
+ Assert.state(SimpMessageType.MESSAGE.equals(getMessageType()),
+ "Unexpected message type " + getMessage());
+
if (getCommand() == null) {
- setHeader(COMMAND_HEADER, command);
+ setHeader(COMMAND_HEADER, StompCommand.SEND);
+ }
+ else if (!getCommand().equals(StompCommand.SEND)) {
+ throw new IllegalStateException("Unexpected STOMP command " + getCommand());
+ }
+ }
+
+ public void updateStompCommandAsServerMessage() {
+
+ Assert.state(SimpMessageType.MESSAGE.equals(getMessageType()),
+ "Unexpected message type " + getMessage());
+
+ if ((getCommand() == null) || getCommand().equals(StompCommand.SEND)) {
+ setHeader(COMMAND_HEADER, StompCommand.MESSAGE);
+ }
+ else if (!getCommand().equals(StompCommand.MESSAGE)) {
+ throw new IllegalStateException("Unexpected STOMP command " + getCommand());
}
}
diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java
index 128eaa2d5cc..59ebf7f8c85 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java
@@ -150,9 +150,6 @@ public class StompProtocolHandler implements SubProtocolHandler {
public void handleMessageToClient(WebSocketSession session, Message> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
- if (headers.getCommand() == null && SimpMessageType.MESSAGE == headers.getMessageType()) {
- headers.setCommandIfNotSet(StompCommand.MESSAGE);
- }
if (headers.getMessageType() == SimpMessageType.CONNECT_ACK) {
StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
@@ -160,6 +157,9 @@ public class StompProtocolHandler implements SubProtocolHandler {
connectedHeaders.setHeartbeat(0, 0); // no heart-beat support with simple broker
headers = connectedHeaders;
}
+ else if (SimpMessageType.MESSAGE.equals(headers.getMessageType())) {
+ headers.updateStompCommandAsServerMessage();
+ }
if (headers.getCommand() == StompCommand.CONNECTED) {
augmentConnectedHeaders(headers, session);
diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/AnnotationMethodIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/AnnotationMethodIntegrationTests.java
index bf224b1a2de..2ae643f7e3d 100644
--- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/AnnotationMethodIntegrationTests.java
+++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/AnnotationMethodIntegrationTests.java
@@ -80,7 +80,7 @@ public class AnnotationMethodIntegrationTests extends AbstractWebSocketIntegrati
@Test
- public void simpleController() throws Exception {
+ public void sendMessageToController() throws Exception {
TextMessage message = create(StompCommand.SEND).headers("destination:/app/simple").build();
WebSocketSession session = doHandshake(new TestClientWebSocketHandler(0, message), "/ws").get();
@@ -95,10 +95,10 @@ public class AnnotationMethodIntegrationTests extends AbstractWebSocketIntegrati
}
@Test
- public void incrementController() throws Exception {
+ public void sendMessageToControllerAndReceiveReplyViaTopic() throws Exception {
TextMessage message1 = create(StompCommand.SUBSCRIBE).headers(
- "id:subs1", "destination:/topic/increment").body("5").build();
+ "id:subs1", "destination:/topic/increment").build();
TextMessage message2 = create(StompCommand.SEND).headers(
"destination:/app/topic/increment").body("5").build();
@@ -114,6 +114,28 @@ public class AnnotationMethodIntegrationTests extends AbstractWebSocketIntegrati
}
}
+ // SPR-10930
+
+ @Test
+ public void sendMessageToBrokerAndReceiveReplyViaTopic() throws Exception {
+
+ TextMessage message1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", "destination:/topic/foo").build();
+ TextMessage message2 = create(StompCommand.SEND).headers("destination:/topic/foo").body("5").build();
+
+ TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, message1, message2);
+ WebSocketSession session = doHandshake(clientHandler, "/ws").get();
+
+ try {
+ assertTrue(clientHandler.latch.await(2, TimeUnit.SECONDS));
+
+ String payload = clientHandler.actual.get(0).getPayload();
+ assertTrue("Expected STOMP Command=MESSAGE, got " + payload, payload.startsWith("MESSAGE\n"));
+ }
+ finally {
+ session.close();
+ }
+ }
+
@IntegrationTestController
static class SimpleController {
diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandlerTests.java
index f61bf7516b6..7974a47f0e6 100644
--- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandlerTests.java
+++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandlerTests.java
@@ -31,7 +31,6 @@ import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
import static org.junit.Assert.*;
-
import static org.mockito.Mockito.*;
diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java
index 26931c57320..8f0a9fa7476 100644
--- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java
+++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java
@@ -146,7 +146,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
@Test(expected=MessageDeliveryException.class)
public void messageDeliverExceptionIfSystemSessionForwardFails() throws Exception {
- StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
+ StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
this.relay.handleMessage(MessageBuilder.withPayloadAndHeaders("test", headers).build());
}
diff --git a/spring-messaging/src/test/resources/log4j.xml b/spring-messaging/src/test/resources/log4j.xml
index dfcc9491666..75d03e82684 100644
--- a/spring-messaging/src/test/resources/log4j.xml
+++ b/spring-messaging/src/test/resources/log4j.xml
@@ -12,7 +12,7 @@
-
+