diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java index ef23b2d5891..8aac43a72e5 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java @@ -30,6 +30,7 @@ import org.springframework.messaging.MessageChannel; import org.springframework.messaging.handler.websocket.SubProtocolHandler; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.handler.MutableUserQueueSuffixResolver; +import org.springframework.messaging.simp.handler.SimpleUserQueueSuffixResolver; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; @@ -64,7 +65,7 @@ public class StompProtocolHandler implements SubProtocolHandler { private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); - private MutableUserQueueSuffixResolver queueSuffixResolver; + private MutableUserQueueSuffixResolver queueSuffixResolver = new SimpleUserQueueSuffixResolver(); /** diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationTests.java index 79f47d8e7eb..91735a98305 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationTests.java @@ -27,15 +27,12 @@ import org.junit.runners.Parameterized.Parameters; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.messaging.Message; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.simp.AbstractWebSocketIntegrationTests; import org.springframework.messaging.simp.JettyTestServer; import org.springframework.messaging.simp.stomp.StompCommand; -import org.springframework.messaging.simp.stomp.StompHeaderAccessor; -import org.springframework.messaging.simp.stomp.StompMessageConverter; -import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.simp.stomp.StompTextMessageBuilder; import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; import org.springframework.stereotype.Controller; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; @@ -76,16 +73,13 @@ public class WebSocketMessageBrokerConfigurationTests extends AbstractWebSocketI this.server.init(cxt); this.server.start(); - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); - headers.setDestination("/app/foo"); - Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); - byte[] bytes = new StompMessageConverter().fromMessage(message); - final TextMessage webSocketMessage = new TextMessage(new String(bytes)); + final TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.SEND) + .headers("destination:/app/foo").build(); WebSocketHandler clientHandler = new TextWebSocketHandlerAdapter() { @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { - session.sendMessage(webSocketMessage); + session.sendMessage(textMessage); } }; diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java index 4a4b3f4fe6a..f2716271378 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java @@ -24,6 +24,7 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.web.socket.TextMessage; import static org.junit.Assert.*; @@ -41,14 +42,17 @@ public class StompMessageConverterTests { this.converter = new StompMessageConverter(); } - @SuppressWarnings("unchecked") @Test public void connectFrame() throws Exception { - String accept = "accept-version:1.1\n"; - String host = "host:github.org\n"; - String frame = "\n\n\nCONNECT\n" + accept + host + "\n"; - Message message = (Message) this.converter.toMessage(frame.getBytes("UTF-8")); + String accept = "accept-version:1.1"; + String host = "host:github.org"; + + TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT) + .headers(accept, host).build(); + + @SuppressWarnings("unchecked") + Message message = (Message) this.converter.toMessage(textMessage.getPayload()); assertEquals(0, message.getPayload().length); @@ -80,11 +84,14 @@ public class StompMessageConverterTests { @Test public void connectWithEscapes() throws Exception { - String accept = "accept-version:1.1\n"; - String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; - String frame = "CONNECT\n" + accept + host + "\n"; + String accept = "accept-version:1.1"; + String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org"; + + TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT) + .headers(accept, host).build(); + @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(frame.getBytes("UTF-8")); + Message message = (Message) this.converter.toMessage(textMessage.getPayload()); assertEquals(0, message.getPayload().length); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java new file mode 100644 index 00000000000..62d16318fca --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java @@ -0,0 +1,97 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.util.Arrays; +import java.util.HashSet; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.support.TestPrincipal; +import org.springframework.web.socket.support.TestWebSocketSession; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Test fixture for {@link StompProtocolHandler} tests. + * + * @author Rossen Stoyanchev + */ +public class StompProtocolHandlerTests { + + private StompProtocolHandler stompHandler; + + private TestWebSocketSession session; + + private MessageChannel channel; + + private ArgumentCaptor messageCaptor; + + + @Before + public void setup() { + this.stompHandler = new StompProtocolHandler(); + this.channel = Mockito.mock(MessageChannel.class); + this.messageCaptor = ArgumentCaptor.forClass(Message.class); + + this.session = new TestWebSocketSession(); + this.session.setId("s1"); + this.session.setPrincipal(new TestPrincipal("joe")); + } + + @Test + public void handleConnect() { + + TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers( + "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build(); + + this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel); + + verify(this.channel).send(this.messageCaptor.capture()); + Message actual = this.messageCaptor.getValue(); + assertNotNull(actual); + + StompHeaderAccessor headers = StompHeaderAccessor.wrap(actual); + assertEquals(StompCommand.CONNECT, headers.getCommand()); + assertEquals("s1", headers.getSessionId()); + assertEquals("joe", headers.getUser().getName()); + assertEquals("guest", headers.getLogin()); + assertEquals("PROTECTED", headers.getPasscode()); + assertArrayEquals(new long[] {10000, 10000}, headers.getHeartbeat()); + assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), headers.getAcceptVersion()); + + // Check CONNECTED reply + + assertEquals(1, this.session.getSentMessages().size()); + textMessage = (TextMessage) this.session.getSentMessages().get(0); + Message message = new StompMessageConverter().toMessage(textMessage.getPayload()); + StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message); + + assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand()); + assertEquals("1.1", replyHeaders.getVersion()); + assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat()); + assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0)); + assertEquals("s1", replyHeaders.getNativeHeader("queue-suffix").get(0)); + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompTextMessageBuilder.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompTextMessageBuilder.java new file mode 100644 index 00000000000..e8c2df57f30 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompTextMessageBuilder.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.springframework.web.socket.TextMessage; + + +/** + * A builder for creating WebSocket messages with STOMP frame content. + * + * @author Rossen Stoyanchev + */ +public class StompTextMessageBuilder { + + private StompCommand command; + + private final List headerLines = new ArrayList(); + + private String body; + + + private StompTextMessageBuilder(StompCommand command) { + this.command = command; + } + + public static StompTextMessageBuilder create(StompCommand command) { + return new StompTextMessageBuilder(command); + } + + public StompTextMessageBuilder headers(String... headerLines) { + this.headerLines.addAll(Arrays.asList(headerLines)); + return this; + } + + public StompTextMessageBuilder body(String body) { + this.body = body; + return this; + } + + public TextMessage build() { + StringBuilder sb = new StringBuilder(this.command.name()).append("\n"); + for (String line : this.headerLines) { + sb.append(line).append("\n"); + } + sb.append("\n"); + if (this.body != null) { + sb.append(this.body); + } + sb.append("\u0000"); + return new TextMessage(sb.toString()); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java index 69eb590979f..c2f9fc80934 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java @@ -17,26 +17,26 @@ package org.springframework.web.socket.server.config; import java.util.Arrays; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; -import org.mockito.Mockito; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import org.springframework.web.socket.AbstractWebSocketIntegrationTests; import org.springframework.web.socket.JettyTestServer; -import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; import org.springframework.web.socket.client.jetty.JettyWebSocketClient; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; -import static org.mockito.Matchers.*; -import static org.mockito.Mockito.*; +import static org.junit.Assert.*; /** @@ -63,13 +63,10 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes this.server.init(cxt); this.server.start(); - WebSocketHandler clientHandler = Mockito.mock(WebSocketHandler.class); - WebSocketHandler serverHandler = cxt.getBean(WebSocketHandler.class); + this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/ws"); - this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/ws"); - - verify(serverHandler).afterConnectionEstablished(any(WebSocketSession.class)); - verify(clientHandler).afterConnectionEstablished(any(WebSocketSession.class)); + TestWebSocketHandler serverHandler = cxt.getBean(TestWebSocketHandler.class); + assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS)); } @Test @@ -81,13 +78,10 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes this.server.init(cxt); this.server.start(); - WebSocketHandler clientHandler = Mockito.mock(WebSocketHandler.class); - WebSocketHandler serverHandler = cxt.getBean(WebSocketHandler.class); - - this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/sockjs/websocket"); + this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/sockjs/websocket"); - verify(serverHandler).afterConnectionEstablished(any(WebSocketSession.class)); - verify(clientHandler).afterConnectionEstablished(any(WebSocketSession.class)); + TestWebSocketHandler serverHandler = cxt.getBean(TestWebSocketHandler.class); + assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS)); } @@ -110,8 +104,18 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes } @Bean - public WebSocketHandler serverHandler() { - return Mockito.mock(WebSocketHandler.class); + public TestWebSocketHandler serverHandler() { + return new TestWebSocketHandler(); + } + } + + private static class TestWebSocketHandler extends WebSocketHandlerAdapter { + + private CountDownLatch latch = new CountDownLatch(1); + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + this.latch.countDown(); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestPrincipal.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestPrincipal.java new file mode 100644 index 00000000000..866bdba2ba1 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestPrincipal.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.support; + +import java.security.Principal; + + +/** + * An implementation of Prinicipal for testing. + * @author Rossen Stoyanchev + */ +public class TestPrincipal implements Principal { + + private String name; + + public TestPrincipal(String name) { + this.name = name; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof TestPrincipal)) { + return false; + } + TestPrincipal p = (TestPrincipal) obj; + return this.name.equals(p.name); + } + + @Override + public int hashCode() { + return this.name.hashCode(); + } + +}