diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java index a9cbb09539e..f5009775df1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java @@ -61,6 +61,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler { Assert.notNull(this.clientChannel, "No clientChannel to send messages to"); + Message returnMessage = (Message) returnValue; if (message == null) { return; } @@ -68,7 +69,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler { PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); Assert.notNull(headers.getSubscriptionId(), "No subscription id: " + message); - PubSubHeaderAccesssor returnHeaders = PubSubHeaderAccesssor.wrap(message); + PubSubHeaderAccesssor returnHeaders = PubSubHeaderAccesssor.wrap(returnMessage); returnHeaders.setSessionId(headers.getSessionId()); returnHeaders.setSubscriptionId(headers.getSubscriptionId()); @@ -76,11 +77,10 @@ public class MessageReturnValueHandler implements ReturnValueHandler { returnHeaders.setDestination(headers.getDestination()); } - Message returnMessage = MessageBuilder.withPayload( - message.getPayload()).copyHeaders(headers.toHeaders()).build(); + returnMessage = MessageBuilder.withPayload( + returnMessage.getPayload()).copyHeaders(returnHeaders.toHeaders()).build(); this.clientChannel.send(returnMessage); } - } diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java index c2f88e9db49..8aaf2d766b7 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java @@ -32,21 +32,22 @@ import static org.junit.Assert.*; */ public class StompMessageConverterTests { - private StompMessageConverter> converter; + private StompMessageConverter converter; @Before public void setup() { - this.converter = new StompMessageConverter>(); + this.converter = new StompMessageConverter(); } + @SuppressWarnings("unchecked") @Test public void connectFrame() throws Exception { String accept = "accept-version:1.1\n"; String host = "host:github.org\n"; String frame = "\n\n\nCONNECT\n" + accept + host + "\n"; - Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); + Message message = (Message) this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); assertEquals(0, message.getPayload().length); @@ -76,7 +77,8 @@ public class StompMessageConverterTests { String accept = "accept-version:1.1\n"; String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; String frame = "CONNECT\n" + accept + host + "\n"; - Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); + @SuppressWarnings("unchecked") + Message message = (Message) this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); assertEquals(0, message.getPayload().length); @@ -97,7 +99,8 @@ public class StompMessageConverterTests { String accept = "accept-version:1.2\n"; String host = "host:github.org\n"; String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); + @SuppressWarnings("unchecked") + Message message = (Message) this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); assertEquals(0, message.getPayload().length); @@ -118,7 +121,8 @@ public class StompMessageConverterTests { String accept = "accept-version:1.1\n"; String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); + @SuppressWarnings("unchecked") + Message message = (Message) this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); assertEquals(0, message.getPayload().length);