@ -16,19 +16,13 @@
@@ -16,19 +16,13 @@
package org.springframework.messaging.simp.config ;
import java.util.List ;
import java.util.Map ;
import org.junit.Before ;
import org.junit.Test ;
import org.mockito.ArgumentCaptor ;
import org.mockito.Mockito ;
import org.springframework.context.annotation.AnnotationConfigApplicationContext ;
import org.springframework.context.annotation.Bean ;
import org.springframework.context.annotation.Configuration ;
import org.springframework.messaging.Message ;
import org.springframework.messaging.MessageHandler ;
import org.springframework.messaging.SubscribableChannel ;
import org.springframework.messaging.handler.annotation.MessageMapping ;
import org.springframework.messaging.handler.annotation.SendTo ;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler ;
@ -43,6 +37,8 @@ import org.springframework.messaging.simp.stomp.StompCommand;
@@ -43,6 +37,8 @@ import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor ;
import org.springframework.messaging.simp.stomp.StompTextMessageBuilder ;
import org.springframework.messaging.support.MessageBuilder ;
import org.springframework.messaging.support.channel.AbstractSubscribableChannel ;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel ;
import org.springframework.messaging.support.converter.CompositeMessageConverter ;
import org.springframework.messaging.support.converter.DefaultContentTypeResolver ;
import org.springframework.stereotype.Controller ;
@ -52,9 +48,11 @@ import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
@@ -52,9 +48,11 @@ import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.TextMessage ;
import org.springframework.web.socket.support.TestWebSocketSession ;
import java.util.ArrayList ;
import java.util.List ;
import java.util.Map ;
import static org.junit.Assert.* ;
import static org.mockito.Matchers.* ;
import static org.mockito.Mockito.* ;
/ * *
@ -95,27 +93,20 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@@ -95,27 +93,20 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void webSocketRequestChannel ( ) {
SubscribableChannel channel = this . cxtSimpleBroker . getBean ( "webSocketRequestChannel" , SubscribableChannel . class ) ;
ArgumentCaptor < MessageHandler > captor = ArgumentCaptor . forClass ( MessageHandler . class ) ;
verify ( channel , times ( 3 ) ) . subscribe ( captor . capture ( ) ) ;
TestChannel channel = this . cxtSimpleBroker . getBean ( "webSocketRequestChannel" , TestChannel . class ) ;
List < MessageHandler > handlers = channel . handlers ;
List < MessageHandler > values = captor . getAllValues ( ) ;
assertEquals ( 3 , values . size ( ) ) ;
assertTrue ( values . contains ( cxtSimpleBroker . getBean ( SimpAnnotationMethodMessageHandler . class ) ) ) ;
assertTrue ( values . contains ( cxtSimpleBroker . getBean ( UserDestinationMessageHandler . class ) ) ) ;
assertTrue ( values . contains ( cxtSimpleBroker . getBean ( SimpleBrokerMessageHandler . class ) ) ) ;
assertEquals ( 3 , handlers . size ( ) ) ;
assertTrue ( handlers . contains ( cxtSimpleBroker . getBean ( SimpAnnotationMethodMessageHandler . class ) ) ) ;
assertTrue ( handlers . contains ( cxtSimpleBroker . getBean ( UserDestinationMessageHandler . class ) ) ) ;
assertTrue ( handlers . contains ( cxtSimpleBroker . getBean ( SimpleBrokerMessageHandler . class ) ) ) ;
}
@Test
public void webSocketRequestChannelWithStompBroker ( ) {
SubscribableChannel channel = this . cxtStompBroker . getBean ( "webSocketRequestChannel" , SubscribableChannel . class ) ;
TestChannel channel = this . cxtStompBroker . getBean ( "webSocketRequestChannel" , TestChannel . class ) ;
List < MessageHandler > values = channel . handlers ;
ArgumentCaptor < MessageHandler > captor = ArgumentCaptor . forClass ( MessageHandler . class ) ;
verify ( channel , times ( 3 ) ) . subscribe ( captor . capture ( ) ) ;
List < MessageHandler > values = captor . getAllValues ( ) ;
assertEquals ( 3 , values . size ( ) ) ;
assertTrue ( values . contains ( cxtStompBroker . getBean ( SimpAnnotationMethodMessageHandler . class ) ) ) ;
assertTrue ( values . contains ( cxtStompBroker . getBean ( UserDestinationMessageHandler . class ) ) ) ;
@ -125,16 +116,13 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@@ -125,16 +116,13 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void webSocketRequestChannelSendMessage ( ) throws Exception {
Subscribable Channel channel = this . cxtSimpleBroker . getBean ( "webSocketRequestChannel" , Subscribable Channel. class ) ;
Test Channel channel = this . cxtSimpleBroker . getBean ( "webSocketRequestChannel" , Test Channel. class ) ;
SubProtocolWebSocketHandler webSocketHandler = this . cxtSimpleBroker . getBean ( SubProtocolWebSocketHandler . class ) ;
TextMessage textMessage = StompTextMessageBuilder . create ( StompCommand . SEND ) . headers ( "destination:/foo" ) . build ( ) ;
webSocketHandler . handleMessage ( new TestWebSocketSession ( ) , textMessage ) ;
ArgumentCaptor < Message > captor = ArgumentCaptor . forClass ( Message . class ) ;
verify ( channel ) . send ( captor . capture ( ) ) ;
Message message = captor . getValue ( ) ;
Message < ? > message = channel . messages . get ( 0 ) ;
StompHeaderAccessor headers = StompHeaderAccessor . wrap ( message ) ;
assertEquals ( SimpMessageType . MESSAGE , headers . getMessageType ( ) ) ;
@ -143,15 +131,17 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@@ -143,15 +131,17 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void webSocketResponseChannel ( ) {
SubscribableChannel channel = this . cxtSimpleBroker . getBean ( "webSocketResponseChannel" , SubscribableChannel . class ) ;
verify ( channel ) . subscribe ( any ( SubProtocolWebSocketHandler . class ) ) ;
verifyNoMoreInteractions ( channel ) ;
TestChannel channel = this . cxtSimpleBroker . getBean ( "webSocketResponseChannel" , TestChannel . class ) ;
List < MessageHandler > values = channel . handlers ;
assertEquals ( 1 , values . size ( ) ) ;
assertTrue ( values . get ( 0 ) instanceof SubProtocolWebSocketHandler ) ;
}
@Test
public void webSocketResponseChannelUsedByAnnotatedMethod ( ) {
Subscribable Channel channel = this . cxtSimpleBroker . getBean ( "webSocketResponseChannel" , Subscribable Channel. class ) ;
Test Channel channel = this . cxtSimpleBroker . getBean ( "webSocketResponseChannel" , Test Channel. class ) ;
SimpAnnotationMethodMessageHandler messageHandler = this . cxtSimpleBroker . getBean ( SimpAnnotationMethodMessageHandler . class ) ;
StompHeaderAccessor headers = StompHeaderAccessor . create ( StompCommand . SUBSCRIBE ) ;
@ -160,12 +150,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@@ -160,12 +150,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
headers . setDestination ( "/foo" ) ;
Message < ? > message = MessageBuilder . withPayload ( new byte [ 0 ] ) . setHeaders ( headers ) . build ( ) ;
when ( channel . send ( any ( Message . class ) ) ) . thenReturn ( true ) ;
messageHandler . handleMessage ( message ) ;
ArgumentCaptor < Message > captor = ArgumentCaptor . forClass ( Message . class ) ;
verify ( channel ) . send ( captor . capture ( ) ) ;
message = captor . getValue ( ) ;
message = channel . messages . get ( 0 ) ;
headers = StompHeaderAccessor . wrap ( message ) ;
assertEquals ( SimpMessageType . MESSAGE , headers . getMessageType ( ) ) ;
@ -175,7 +162,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@@ -175,7 +162,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void webSocketResponseChannelUsedBySimpleBroker ( ) {
Subscribable Channel channel = this . cxtSimpleBroker . getBean ( "webSocketResponseChannel" , Subscribable Channel. class ) ;
Test Channel channel = this . cxtSimpleBroker . getBean ( "webSocketResponseChannel" , Test Channel. class ) ;
SimpleBrokerMessageHandler broker = this . cxtSimpleBroker . getBean ( SimpleBrokerMessageHandler . class ) ;
StompHeaderAccessor headers = StompHeaderAccessor . create ( StompCommand . SUBSCRIBE ) ;
@ -193,12 +180,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@@ -193,12 +180,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
message = MessageBuilder . withPayload ( "bar" . getBytes ( ) ) . setHeaders ( headers ) . build ( ) ;
// message
when ( channel . send ( any ( Message . class ) ) ) . thenReturn ( true ) ;
broker . handleMessage ( message ) ;
ArgumentCaptor < Message > captor = ArgumentCaptor . forClass ( Message . class ) ;
verify ( channel ) . send ( captor . capture ( ) ) ;
message = captor . getValue ( ) ;
message = channel . messages . get ( 0 ) ;
headers = StompHeaderAccessor . wrap ( message ) ;
assertEquals ( SimpMessageType . MESSAGE , headers . getMessageType ( ) ) ;
@ -208,45 +192,36 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@@ -208,45 +192,36 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void brokerChannel ( ) {
SubscribableChannel channel = this . cxtSimpleBroker . getBean ( "brokerChannel" , SubscribableChannel . class ) ;
ArgumentCaptor < MessageHandler > captor = ArgumentCaptor . forClass ( MessageHandler . class ) ;
verify ( channel , times ( 2 ) ) . subscribe ( captor . capture ( ) ) ;
TestChannel channel = this . cxtSimpleBroker . getBean ( "brokerChannel" , TestChannel . class ) ;
List < MessageHandler > handlers = channel . handlers ;
List < MessageHandler > values = captor . getAllValues ( ) ;
assertEquals ( 2 , values . size ( ) ) ;
assertTrue ( values . contains ( cxtSimpleBroker . getBean ( UserDestinationMessageHandler . class ) ) ) ;
assertTrue ( values . contains ( cxtSimpleBroker . getBean ( SimpleBrokerMessageHandler . class ) ) ) ;
assertEquals ( 2 , handlers . size ( ) ) ;
assertTrue ( handlers . contains ( cxtSimpleBroker . getBean ( UserDestinationMessageHandler . class ) ) ) ;
assertTrue ( handlers . contains ( cxtSimpleBroker . getBean ( SimpleBrokerMessageHandler . class ) ) ) ;
}
@Test
public void brokerChannelWithStompBroker ( ) {
SubscribableChannel channel = this . cxtStompBroker . getBean ( "brokerChannel" , SubscribableChannel . class ) ;
ArgumentCaptor < MessageHandler > captor = ArgumentCaptor . forClass ( MessageHandler . class ) ;
verify ( channel , times ( 2 ) ) . subscribe ( captor . capture ( ) ) ;
TestChannel channel = this . cxtStompBroker . getBean ( "brokerChannel" , TestChannel . class ) ;
List < MessageHandler > handlers = channel . handlers ;
List < MessageHandler > values = captor . getAllValues ( ) ;
assertEquals ( 2 , values . size ( ) ) ;
assertTrue ( values . contains ( cxtStompBroker . getBean ( UserDestinationMessageHandler . class ) ) ) ;
assertTrue ( values . contains ( cxtStompBroker . getBean ( StompBrokerRelayMessageHandler . class ) ) ) ;
assertEquals ( 2 , handlers . size ( ) ) ;
assertTrue ( handlers . contains ( cxtStompBroker . getBean ( UserDestinationMessageHandler . class ) ) ) ;
assertTrue ( handlers . contains ( cxtStompBroker . getBean ( StompBrokerRelayMessageHandler . class ) ) ) ;
}
@Test
public void brokerChannelUsedByAnnotatedMethod ( ) {
Subscribable Channel channel = this . cxtSimpleBroker . getBean ( "brokerChannel" , Subscribable Channel. class ) ;
Test Channel channel = this . cxtSimpleBroker . getBean ( "brokerChannel" , Test Channel. class ) ;
SimpAnnotationMethodMessageHandler messageHandler = this . cxtSimpleBroker . getBean ( SimpAnnotationMethodMessageHandler . class ) ;
StompHeaderAccessor headers = StompHeaderAccessor . create ( StompCommand . SEND ) ;
headers . setDestination ( "/foo" ) ;
Message < ? > message = MessageBuilder . withPayload ( new byte [ 0 ] ) . setHeaders ( headers ) . build ( ) ;
when ( channel . send ( any ( Message . class ) ) ) . thenReturn ( true ) ;
messageHandler . handleMessage ( message ) ;
ArgumentCaptor < Message > captor = ArgumentCaptor . forClass ( Message . class ) ;
verify ( channel ) . send ( captor . capture ( ) ) ;
message = captor . getValue ( ) ;
message = channel . messages . get ( 0 ) ;
headers = StompHeaderAccessor . wrap ( message ) ;
assertEquals ( SimpMessageType . MESSAGE , headers . getMessageType ( ) ) ;
@ -256,7 +231,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@@ -256,7 +231,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void brokerChannelUsedByUserDestinationMessageHandler ( ) {
Subscribable Channel channel = this . cxtSimpleBroker . getBean ( "brokerChannel" , Subscribable Channel. class ) ;
Test Channel channel = this . cxtSimpleBroker . getBean ( "brokerChannel" , Test Channel. class ) ;
UserDestinationMessageHandler messageHandler = this . cxtSimpleBroker . getBean ( UserDestinationMessageHandler . class ) ;
this . cxtSimpleBroker . getBean ( UserSessionRegistry . class ) . registerSessionId ( "joe" , "s1" ) ;
@ -265,12 +240,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@@ -265,12 +240,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
headers . setDestination ( "/user/joe/foo" ) ;
Message < ? > message = MessageBuilder . withPayload ( new byte [ 0 ] ) . setHeaders ( headers ) . build ( ) ;
when ( channel . send ( any ( Message . class ) ) ) . thenReturn ( true ) ;
messageHandler . handleMessage ( message ) ;
ArgumentCaptor < Message > captor = ArgumentCaptor . forClass ( Message . class ) ;
verify ( channel ) . send ( captor . capture ( ) ) ;
message = captor . getValue ( ) ;
message = channel . messages . get ( 0 ) ;
headers = StompHeaderAccessor . wrap ( message ) ;
assertEquals ( SimpMessageType . MESSAGE , headers . getMessageType ( ) ) ;
@ -340,19 +312,39 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@@ -340,19 +312,39 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Override
@Bean
public SubscribableChannel webSocketRequestChannel ( ) {
return Mockito . mock ( SubscribableChannel . class ) ;
public Abstract SubscribableChannel webSocketRequestChannel ( ) {
return new TestChannel ( ) ;
}
@Override
@Bean
public SubscribableChannel webSocketResponseChannel ( ) {
return Mockito . mock ( SubscribableChannel . class ) ;
public AbstractSubscribableChannel webSocketResponseChannel ( ) {
return new TestChannel ( ) ;
}
@Override
public AbstractSubscribableChannel brokerChannel ( ) {
return new TestChannel ( ) ;
}
}
private static class TestChannel extends ExecutorSubscribableChannel {
private final List < MessageHandler > handlers = new ArrayList < > ( ) ;
private final List < Message < ? > > messages = new ArrayList < > ( ) ;
@Override
public boolean subscribeInternal ( MessageHandler handler ) {
this . handlers . add ( handler ) ;
return super . subscribeInternal ( handler ) ;
}
@Override
public SubscribableChannel brokerChannel ( ) {
return Mockito . mock ( SubscribableChannel . class ) ;
public boolean sendInternal ( Message < ? > message , long timeout ) {
this . messages . add ( message ) ;
return true ;
}
}