From 371d93b3463c5157f3c4b2b809084ddb00ae2b9e Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 25 Sep 2014 23:22:12 -0400 Subject: [PATCH] Detect unsent DISCONNECT messages This change uses a ChannelInterceptor (inserted at index 0) to detect when a DISCONNECT message is precluded from being sent on the clientInboundChannel. This can happen if another interceptor allows a runtime exception out from preSend or returns false. It is crucial for such messages to be processed, so when detected they're processed still. Issue: SPR-12218 --- .../broker/AbstractBrokerMessageHandler.java | 30 +++++++++++++++++++ .../support/AbstractMessageChannel.java | 14 +++++++++ .../MessageBrokerConfigurationTests.java | 2 +- ...essageBrokerBeanDefinitionParserTests.java | 6 ++-- 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java index 28004020bc3..4d55ac3fb67 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java @@ -30,6 +30,11 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.SubscribableChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.AbstractMessageChannel; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.ChannelInterceptorAdapter; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -69,6 +74,8 @@ public abstract class AbstractBrokerMessageHandler private final Object lifecycleMonitor = new Object(); + private ChannelInterceptor unsentDisconnectInterceptor = new UnsentDisconnectChannelInterceptor(); + /** * Constructor with no destination prefixes (matches all destinations). @@ -165,6 +172,9 @@ public abstract class AbstractBrokerMessageHandler } this.clientInboundChannel.subscribe(this); this.brokerChannel.subscribe(this); + if (this.clientInboundChannel instanceof AbstractMessageChannel) { + ((AbstractMessageChannel) this.clientInboundChannel).addInterceptor(0, this.unsentDisconnectInterceptor); + } startInternal(); this.running = true; if (logger.isInfoEnabled()) { @@ -185,6 +195,9 @@ public abstract class AbstractBrokerMessageHandler stopInternal(); this.clientInboundChannel.unsubscribe(this); this.brokerChannel.unsubscribe(this); + if (this.clientInboundChannel instanceof AbstractMessageChannel) { + ((AbstractMessageChannel) this.clientInboundChannel).removeInterceptor(this.unsentDisconnectInterceptor); + } this.running = false; if (logger.isDebugEnabled()) { logger.info("Stopped."); @@ -264,4 +277,21 @@ public abstract class AbstractBrokerMessageHandler } } + + /** + * Detect unsent DISCONNECT messages and process them anyway. + */ + private class UnsentDisconnectChannelInterceptor extends ChannelInterceptorAdapter { + + @Override + public void afterSendCompletion(Message message, MessageChannel channel, boolean sent, Exception ex) { + if (!sent) { + SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders()); + if (SimpMessageType.DISCONNECT.equals(messageType)) { + logger.debug("Detected unsent DISCONNECT message. Processing anyway."); + handleMessage(message); + } + } + } + } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/AbstractMessageChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/AbstractMessageChannel.java index 82090745b97..1f4f21b634d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/AbstractMessageChannel.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/AbstractMessageChannel.java @@ -81,6 +81,13 @@ public abstract class AbstractMessageChannel implements MessageChannel, BeanName this.interceptors.add(interceptor); } + /** + * Add a channel interceptor at the specified index. + */ + public void addInterceptor(int index, ChannelInterceptor interceptor) { + this.interceptors.add(index, interceptor); + } + /** * Return a read-only list of the configured interceptors. */ @@ -88,6 +95,13 @@ public abstract class AbstractMessageChannel implements MessageChannel, BeanName return Collections.unmodifiableList(this.interceptors); } + /** + * Remove the given interceptor. + */ + public boolean removeInterceptor(ChannelInterceptor interceptor) { + return this.interceptors.remove(interceptor); + } + @Override public final boolean send(Message message) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java index d2f596cc497..ce4aceb30db 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java @@ -130,7 +130,7 @@ public class MessageBrokerConfigurationTests { AbstractSubscribableChannel channel = this.customContext.getBean( "clientInboundChannel", AbstractSubscribableChannel.class); - assertEquals(1, channel.getInterceptors().size()); + assertEquals(2, channel.getInterceptors().size()); ThreadPoolTaskExecutor taskExecutor = this.customContext.getBean( "clientInboundChannelExecutor", ThreadPoolTaskExecutor.class); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index cdfb108425c..d02dad11f89 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -173,7 +173,7 @@ public class MessageBrokerBeanDefinitionParserTests { List> subscriberTypes = Arrays.>asList(SimpAnnotationMethodMessageHandler.class, UserDestinationMessageHandler.class, SimpleBrokerMessageHandler.class); - testChannel("clientInboundChannel", subscriberTypes, 0); + testChannel("clientInboundChannel", subscriberTypes, 1); testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); @@ -241,7 +241,7 @@ public class MessageBrokerBeanDefinitionParserTests { List> subscriberTypes = Arrays.>asList(SimpAnnotationMethodMessageHandler.class, UserDestinationMessageHandler.class, StompBrokerRelayMessageHandler.class); - testChannel("clientInboundChannel", subscriberTypes, 0); + testChannel("clientInboundChannel", subscriberTypes, 1); testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); @@ -314,7 +314,7 @@ public class MessageBrokerBeanDefinitionParserTests { Arrays.>asList(SimpAnnotationMethodMessageHandler.class, UserDestinationMessageHandler.class, SimpleBrokerMessageHandler.class); - testChannel("clientInboundChannel", subscriberTypes, 1); + testChannel("clientInboundChannel", subscriberTypes, 2); testExecutor("clientInboundChannel", 100, 200, 600); subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class);