diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/AbstractMessageChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/AbstractMessageChannel.java index 4324b5bbab2..9c421eca3ae 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/AbstractMessageChannel.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/AbstractMessageChannel.java @@ -16,6 +16,8 @@ package org.springframework.messaging.support; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import org.apache.commons.logging.Log; @@ -39,7 +41,7 @@ public abstract class AbstractMessageChannel implements MessageChannel, BeanName protected final Log logger = LogFactory.getLog(getClass()); - private final ChannelInterceptorChain interceptorChain = new ChannelInterceptorChain(); + private final List interceptors = new ArrayList(5); private String beanName; @@ -68,28 +70,22 @@ public abstract class AbstractMessageChannel implements MessageChannel, BeanName * Set the list of channel interceptors. This will clear any existing interceptors. */ public void setInterceptors(List interceptors) { - this.interceptorChain.set(interceptors); + this.interceptors.clear(); + this.interceptors.addAll(interceptors); } /** * Add a channel interceptor to the end of the list. */ public void addInterceptor(ChannelInterceptor interceptor) { - this.interceptorChain.add(interceptor); + this.interceptors.add(interceptor); } /** * Return a read-only list of the configured interceptors. */ public List getInterceptors() { - return this.interceptorChain.getInterceptors(); - } - - /** - * Exposes the interceptor list for subclasses. - */ - protected ChannelInterceptorChain getInterceptorChain() { - return this.interceptorChain; + return Collections.unmodifiableList(this.interceptors); } @@ -101,20 +97,29 @@ public abstract class AbstractMessageChannel implements MessageChannel, BeanName @Override public final boolean send(Message message, long timeout) { Assert.notNull(message, "Message must not be null"); - message = this.interceptorChain.preSend(message, this); - if (message == null) { - return false; - } + ChannelInterceptorChain chain = new ChannelInterceptorChain(); + boolean sent = false; try { - boolean sent = sendInternal(message, timeout); - this.interceptorChain.postSend(message, this, sent); + message = chain.applyPreSend(message, this); + if (message == null) { + return false; + } + sent = sendInternal(message, timeout); + chain.applyPostSend(message, this, sent); + chain.triggerAfterSendCompletion(message, this, sent, null); return sent; } - catch (Exception e) { - if (e instanceof MessagingException) { - throw (MessagingException) e; + catch (Exception ex) { + chain.triggerAfterSendCompletion(message, this, sent, ex); + if (ex instanceof MessagingException) { + throw (MessagingException) ex; } - throw new MessageDeliveryException(message,"Failed to send message to " + this, e); + throw new MessageDeliveryException(message,"Failed to send message to " + this, ex); + } + catch (Error ex) { + MessageDeliveryException ex2 = new MessageDeliveryException(message, "Failed to send message to " + this, ex); + chain.triggerAfterSendCompletion(message, this, sent, ex2); + throw ex2; } } @@ -123,7 +128,85 @@ public abstract class AbstractMessageChannel implements MessageChannel, BeanName @Override public String toString() { - return "MessageChannel[name=" + this.beanName + "]"; + return getClass().getSimpleName() + "[" + this.beanName + "]"; + } + + + /** + * Assists with the invocation of the configured channel interceptors. + */ + protected class ChannelInterceptorChain { + + private int sendInterceptorIndex = -1; + + private int receiveInterceptorIndex = -1; + + + public Message applyPreSend(Message message, MessageChannel channel) { + for (ChannelInterceptor interceptor : interceptors) { + message = interceptor.preSend(message, channel); + if (message == null) { + String name = interceptor.getClass().getSimpleName(); + logger.debug(name + " returned null from preSend, i.e. precluding the send."); + triggerAfterSendCompletion(message, channel, false, null); + return null; + } + this.sendInterceptorIndex++; + } + return message; + } + + public void applyPostSend(Message message, MessageChannel channel, boolean sent) { + for (ChannelInterceptor interceptor : interceptors) { + interceptor.postSend(message, channel, sent); + } + } + + public void triggerAfterSendCompletion(Message message, MessageChannel channel, boolean sent, Exception ex) { + for (int i = this.sendInterceptorIndex; i >= 0; i--) { + ChannelInterceptor interceptor = interceptors.get(i); + try { + interceptor.afterSendCompletion(message, channel, sent, ex); + } + catch (Throwable ex2) { + logger.error("Exception from afterSendCompletion in " + interceptor, ex2); + } + } + } + + public boolean applyPreReceive(MessageChannel channel) { + for (ChannelInterceptor interceptor : interceptors) { + if (!interceptor.preReceive(channel)) { + triggerAfterReceiveCompletion(null, channel, null); + return false; + } + this.receiveInterceptorIndex++; + } + return true; + } + + public Message applyPostReceive(Message message, MessageChannel channel) { + for (ChannelInterceptor interceptor : interceptors) { + message = interceptor.postReceive(message, channel); + if (message == null) { + return null; + } + } + return message; + } + + public void triggerAfterReceiveCompletion(Message message, MessageChannel channel, Exception ex) { + for (int i = this.receiveInterceptorIndex; i >= 0; i--) { + ChannelInterceptor interceptor = interceptors.get(i); + try { + interceptor.afterReceiveCompletion(message, channel, ex); + } + catch (Throwable ex2) { + logger.error("Exception from afterReceiveCompletion in " + interceptor, ex2); + } + } + } + } -} +} \ No newline at end of file diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptor.java index 7eba5514499..0f2f2f98c8e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2010 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import org.springframework.messaging.MessageChannel; * {@link MessageChannel}. * * @author Mark Fisher + * @author Rossen Stoyanchev * @since 4.0 */ public interface ChannelInterceptor { @@ -32,7 +33,7 @@ public interface ChannelInterceptor { /** * Invoked before the Message is actually sent to the channel. * This allows for modification of the Message if necessary. - * If this method returns null, then the actual + * If this method returns {@code null} then the actual * send invocation will not occur. */ Message preSend(Message message, MessageChannel channel); @@ -43,6 +44,15 @@ public interface ChannelInterceptor { */ void postSend(Message message, MessageChannel channel, boolean sent); + /** + * Invoked after the completion of a send regardless of any exception that + * have been raised thus allowing for proper resource cleanup. + *

Note that this will be invoked only if preSend successfully completed + * and returned a Message, i.e. it did not return {@code null}. + * @since 4.1 + */ + void afterSendCompletion(Message message, MessageChannel channel, boolean sent, Exception ex); + /** * Invoked as soon as receive is called and before a Message is * actually retrieved. If the return value is 'false', then no @@ -57,4 +67,13 @@ public interface ChannelInterceptor { */ Message postReceive(Message message, MessageChannel channel); + /** + * Invoked after the completion of a receive regardless of any exception that + * have been raised thus allowing for proper resource cleanup. + *

Note that this will be invoked only if preReceive successfully + * completed and returned {@code true}. + * @since 4.1 + */ + void afterReceiveCompletion(Message message, MessageChannel channel, Exception ex); + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptorAdapter.java b/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptorAdapter.java index cb0f8154550..10994797a76 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptorAdapter.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptorAdapter.java @@ -24,23 +24,35 @@ import org.springframework.messaging.MessageChannel; * as a convenience. * * @author Mark Fisher + * @author Rossen Stoyanchev * @since 4.0 */ public abstract class ChannelInterceptorAdapter implements ChannelInterceptor { + @Override public Message preSend(Message message, MessageChannel channel) { return message; } + @Override public void postSend(Message message, MessageChannel channel, boolean sent) { } + @Override + public void afterSendCompletion(Message message, MessageChannel channel, boolean sent, Exception ex) { + } + public boolean preReceive(MessageChannel channel) { return true; } + @Override public Message postReceive(Message message, MessageChannel channel) { return message; } + @Override + public void afterReceiveCompletion(Message message, MessageChannel channel, Exception ex) { + } + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptorChain.java b/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptorChain.java deleted file mode 100644 index 7fd9e900cda..00000000000 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/ChannelInterceptorChain.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright 2002-2014 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.messaging.support; - -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageChannel; - -/** - * A convenience wrapper class for invoking a list of {@link ChannelInterceptor}s. - * - * @author Mark Fisher - * @author Rossen Stoyanchev - * @since 4.0 - */ -class ChannelInterceptorChain { - - private static final Log logger = LogFactory.getLog(ChannelInterceptorChain.class); - - private final List interceptors = new CopyOnWriteArrayList(); - - - public boolean set(List interceptors) { - synchronized (this.interceptors) { - this.interceptors.clear(); - return this.interceptors.addAll(interceptors); - } - } - - public boolean add(ChannelInterceptor interceptor) { - return this.interceptors.add(interceptor); - } - - public List getInterceptors() { - return Collections.unmodifiableList(this.interceptors); - } - - - public Message preSend(Message message, MessageChannel channel) { - for (ChannelInterceptor interceptor : this.interceptors) { - message = interceptor.preSend(message, channel); - if (message == null) { - logger.debug("preSend returned null precluding send"); - return null; - } - } - return message; - } - - public void postSend(Message message, MessageChannel channel, boolean sent) { - for (ChannelInterceptor interceptor : this.interceptors) { - interceptor.postSend(message, channel, sent); - } - } - - public boolean preReceive(MessageChannel channel) { - for (ChannelInterceptor interceptor : this.interceptors) { - if (!interceptor.preReceive(channel)) { - return false; - } - } - return true; - } - - public Message postReceive(Message message, MessageChannel channel) { - for (ChannelInterceptor interceptor : this.interceptors) { - message = interceptor.postReceive(message, channel); - if (message == null) { - return null; - } - } - return message; - } - -} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/ChannelInterceptorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/ChannelInterceptorTests.java index 8a4aa734a8f..d01c6892130 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/support/ChannelInterceptorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/ChannelInterceptorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessagingException; import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; /** * Test fixture for the use of {@link ChannelInterceptor}s. @@ -51,45 +52,61 @@ public class ChannelInterceptorTests { @Test public void preSendInterceptorReturningModifiedMessage() { - - this.channel.addInterceptor(new PreSendReturnsMessageInterceptor()); + Message expected = mock(Message.class); + PreSendInterceptor interceptor = new PreSendInterceptor(); + interceptor.setMessageToReturn(expected); + this.channel.addInterceptor(interceptor); this.channel.send(MessageBuilder.withPayload("test").build()); - assertEquals(1, this.messageHandler.messages.size()); - Message result = this.messageHandler.messages.get(0); + assertEquals(1, this.messageHandler.getMessages().size()); + Message result = this.messageHandler.getMessages().get(0); assertNotNull(result); - assertEquals("test", result.getPayload()); - assertEquals(1, result.getHeaders().get(PreSendReturnsMessageInterceptor.class.getSimpleName())); + assertSame(expected, result); + assertTrue(interceptor.wasAfterCompletionInvoked()); } @Test public void preSendInterceptorReturningNull() { - - PreSendReturnsNullInterceptor interceptor = new PreSendReturnsNullInterceptor(); - this.channel.addInterceptor(interceptor); + PreSendInterceptor interceptor1 = new PreSendInterceptor(); + NullReturningPreSendInterceptor interceptor2 = new NullReturningPreSendInterceptor(); + this.channel.addInterceptor(interceptor1); + this.channel.addInterceptor(interceptor2); Message message = MessageBuilder.withPayload("test").build(); this.channel.send(message); - assertEquals(1, interceptor.counter.get()); - assertEquals(0, this.messageHandler.messages.size()); + assertEquals(1, interceptor1.getCounter().get()); + assertEquals(1, interceptor2.getCounter().get()); + assertEquals(0, this.messageHandler.getMessages().size()); + assertTrue(interceptor1.wasAfterCompletionInvoked()); + assertFalse(interceptor2.wasAfterCompletionInvoked()); } @Test public void postSendInterceptorMessageWasSent() { - final AtomicBoolean invoked = new AtomicBoolean(false); + final AtomicBoolean preSendInvoked = new AtomicBoolean(false); + final AtomicBoolean completionInvoked = new AtomicBoolean(false); this.channel.addInterceptor(new ChannelInterceptorAdapter() { @Override public void postSend(Message message, MessageChannel channel, boolean sent) { + assertInput(message, channel, sent); + preSendInvoked.set(true); + } + @Override + public void afterSendCompletion(Message message, MessageChannel channel, boolean sent, Exception ex) { + assertInput(message, channel, sent); + completionInvoked.set(true); + } + private void assertInput(Message message, MessageChannel channel, boolean sent) { assertNotNull(message); assertNotNull(channel); assertSame(ChannelInterceptorTests.this.channel, channel); assertTrue(sent); - invoked.set(true); } }); this.channel.send(MessageBuilder.withPayload("test").build()); - assertTrue(invoked.get()); + assertTrue(preSendInvoked.get()); + assertTrue(completionInvoked.get()); } @Test @@ -100,19 +117,68 @@ public class ChannelInterceptorTests { return false; } }; - final AtomicBoolean invoked = new AtomicBoolean(false); + final AtomicBoolean preSendInvoked = new AtomicBoolean(false); + final AtomicBoolean completionInvoked = new AtomicBoolean(false); testChannel.addInterceptor(new ChannelInterceptorAdapter() { @Override public void postSend(Message message, MessageChannel channel, boolean sent) { + assertInput(message, channel, sent); + preSendInvoked.set(true); + } + @Override + public void afterSendCompletion(Message message, MessageChannel channel, boolean sent, Exception ex) { + assertInput(message, channel, sent); + completionInvoked.set(true); + } + private void assertInput(Message message, MessageChannel channel, boolean sent) { assertNotNull(message); assertNotNull(channel); assertSame(testChannel, channel); assertFalse(sent); - invoked.set(true); } }); testChannel.send(MessageBuilder.withPayload("test").build()); - assertTrue(invoked.get()); + assertTrue(preSendInvoked.get()); + assertTrue(completionInvoked.get()); + } + + @Test + public void afterCompletionWithSendException() { + final AbstractMessageChannel testChannel = new AbstractMessageChannel() { + @Override + protected boolean sendInternal(Message message, long timeout) { + throw new RuntimeException("Simulated exception"); + } + }; + PreSendInterceptor interceptor1 = new PreSendInterceptor(); + PreSendInterceptor interceptor2 = new PreSendInterceptor(); + testChannel.addInterceptor(interceptor1); + testChannel.addInterceptor(interceptor2); + try { + testChannel.send(MessageBuilder.withPayload("test").build()); + } + catch (Exception ex) { + assertEquals("Simulated exception", ex.getCause().getMessage()); + } + assertTrue(interceptor1.wasAfterCompletionInvoked()); + assertTrue(interceptor2.wasAfterCompletionInvoked()); + } + + @Test + public void afterCompletionWithPreSendException() { + PreSendInterceptor interceptor1 = new PreSendInterceptor(); + PreSendInterceptor interceptor2 = new PreSendInterceptor(); + interceptor2.setExceptionToRaise(new RuntimeException("Simulated exception")); + this.channel.addInterceptor(interceptor1); + this.channel.addInterceptor(interceptor2); + try { + this.channel.send(MessageBuilder.withPayload("test").build()); + } + catch (Exception ex) { + assertEquals("Simulated exception", ex.getCause().getMessage()); + } + assertTrue(interceptor1.wasAfterCompletionInvoked()); + assertFalse(interceptor2.wasAfterCompletionInvoked()); } @@ -120,32 +186,75 @@ public class ChannelInterceptorTests { private List> messages = new ArrayList>(); + + public List> getMessages() { + return this.messages; + } + @Override public void handleMessage(Message message) throws MessagingException { - this.messages.add(message); + this.getMessages().add(message); } } - private static class PreSendReturnsMessageInterceptor extends ChannelInterceptorAdapter { + private abstract static class AbstractTestInterceptor extends ChannelInterceptorAdapter { private AtomicInteger counter = new AtomicInteger(); + private volatile boolean afterCompletionInvoked; + + + public AtomicInteger getCounter() { + return this.counter; + } + + public boolean wasAfterCompletionInvoked() { + return this.afterCompletionInvoked; + } + @Override public Message preSend(Message message, MessageChannel channel) { assertNotNull(message); - return MessageBuilder.fromMessage(message).setHeader( - this.getClass().getSimpleName(), counter.incrementAndGet()).build(); + counter.incrementAndGet(); + return message; + } + + @Override + public void afterSendCompletion(Message message, MessageChannel channel, boolean sent, Exception ex) { + this.afterCompletionInvoked = true; } } - private static class PreSendReturnsNullInterceptor extends ChannelInterceptorAdapter { + private static class PreSendInterceptor extends AbstractTestInterceptor { + + private Message messageToReturn; + + private RuntimeException exceptionToRaise; - private AtomicInteger counter = new AtomicInteger(); + + public void setMessageToReturn(Message messageToReturn) { + this.messageToReturn = messageToReturn; + } + + public void setExceptionToRaise(RuntimeException exception) { + this.exceptionToRaise = exception; + } @Override public Message preSend(Message message, MessageChannel channel) { - assertNotNull(message); - counter.incrementAndGet(); + super.preSend(message, channel); + if (this.exceptionToRaise != null) { + throw this.exceptionToRaise; + } + return (this.messageToReturn != null ? this.messageToReturn : message); + } + } + + private static class NullReturningPreSendInterceptor extends AbstractTestInterceptor { + + @Override + public Message preSend(Message message, MessageChannel channel) { + super.preSend(message, channel); return null; } }