diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java index 2337814fcad..ae8df8f91a2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java @@ -24,7 +24,6 @@ import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.atomic.AtomicBoolean; import org.springframework.http.MediaType; import org.springframework.messaging.Message; @@ -171,10 +170,11 @@ public class StompRelayPubSubMessageHandler extends AbstractP private final Promise> promise; - private final AtomicBoolean isConnected = new AtomicBoolean(false); - private final BlockingQueue messageQueue = new LinkedBlockingQueue(50); + private final Object monitor = new Object(); + + private boolean isConnected = false; public RelaySession(final M message, final StompHeaders stompHeaders) { @@ -224,8 +224,10 @@ public class StompRelayPubSubMessageHandler extends AbstractP StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders()); if (StompCommand.CONNECTED == headers.getStompCommand()) { - this.isConnected.set(true); - flushMessages(promise.get()); + synchronized(this.monitor) { + this.isConnected = true; + flushMessages(promise.get()); + } return; } if (StompCommand.ERROR == headers.getStompCommand()) { @@ -248,14 +250,14 @@ public class StompRelayPubSubMessageHandler extends AbstractP public void forward(M message, StompHeaders headers) { - if (!this.isConnected.get()) { - @SuppressWarnings("unchecked") - M m = (M) MessageBuilder.fromPayloadAndHeaders(message.getPayload(), headers.toMessageHeaders()).build(); - if (logger.isTraceEnabled()) { - logger.trace("Adding to queue message " + m + ", queue size=" + this.messageQueue.size()); + synchronized(this.monitor) { + if (!this.isConnected) { + if (logger.isTraceEnabled()) { + logger.trace("Adding to queue message " + message + ", queue size=" + this.messageQueue.size()); + } + this.messageQueue.add(message); + return; } - this.messageQueue.add(m); - return; } TcpConnection connection = this.promise.get();