diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubChannelRegistryBuilder.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubChannelRegistryBuilder.java index 58828c64023..9d4aac3c805 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubChannelRegistryBuilder.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubChannelRegistryBuilder.java @@ -17,7 +17,6 @@ package org.springframework.web.messaging.support; import java.util.HashSet; -import java.util.List; import java.util.Set; import org.springframework.messaging.Message; @@ -45,12 +44,15 @@ public class PubSubChannelRegistryBuilder { public PubSubChannelRegistryBuilder( + SubscribableChannel, MessageHandler>> clientInputChannel, SubscribableChannel, MessageHandler>> clientOutputChannel, MessageHandler> clientGateway) { + Assert.notNull(clientInputChannel, "clientInputChannel is required"); Assert.notNull(clientOutputChannel, "clientOutputChannel is required"); Assert.notNull(clientGateway, "clientGateway is required"); + this.clientInputChannel = clientInputChannel; this.clientOutputChannel = clientOutputChannel; this.clientOutputChannel.subscribe(clientGateway); this.messageHandlers.add(clientGateway); @@ -58,25 +60,17 @@ public class PubSubChannelRegistryBuilder { public static PubSubChannelRegistryBuilder clientGateway( + SubscribableChannel, MessageHandler>> clientInputChannel, SubscribableChannel, MessageHandler>> clientOutputChannel, MessageHandler> clientGateway) { - return new PubSubChannelRegistryBuilder(clientOutputChannel, clientGateway); + return new PubSubChannelRegistryBuilder(clientInputChannel, clientOutputChannel, clientGateway); } - public PubSubChannelRegistryBuilder clientMessageHandlers( - SubscribableChannel, MessageHandler>> clientInputChannel, - List>> handlers) { - - Assert.notNull(clientInputChannel, "clientInputChannel is required"); - this.clientInputChannel = clientInputChannel; - - for (MessageHandler> handler : handlers) { - this.clientInputChannel.subscribe(handler); - this.messageHandlers.add(handler); - } - + public PubSubChannelRegistryBuilder messageHandler(MessageHandler> handler) { + this.clientInputChannel.subscribe(handler); + this.messageHandlers.add(handler); return this; } @@ -87,6 +81,10 @@ public class PubSubChannelRegistryBuilder { Assert.notNull(messageBrokerChannel, "messageBrokerChannel is required"); Assert.notNull(messageBrokerGateway, "messageBrokerGateway is required"); + if (!this.messageHandlers.contains(messageBrokerGateway)) { + this.clientInputChannel.subscribe(messageBrokerGateway); + } + this.messageBrokerChannel = messageBrokerChannel; this.messageBrokerChannel.subscribe(messageBrokerGateway); this.messageHandlers.add(messageBrokerGateway);