diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java index 112bbb6b80a..6fb0b3c3d00 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java @@ -71,9 +71,9 @@ public abstract class AbstractMethodMessageHandler private Collection destinationPrefixes = new ArrayList(); - private List customArgumentResolvers = new ArrayList(); + private final List customArgumentResolvers = new ArrayList(4); - private List customReturnValueHandlers = new ArrayList(); + private final List customReturnValueHandlers = new ArrayList(4); private HandlerMethodArgumentResolverComposite argumentResolvers = new HandlerMethodArgumentResolverComposite(); @@ -121,10 +121,15 @@ public abstract class AbstractMethodMessageHandler * @param customArgumentResolvers the list of resolvers; never {@code null}. */ public void setCustomArgumentResolvers(List customArgumentResolvers) { - Assert.notNull(customArgumentResolvers, "The 'customArgumentResolvers' cannot be null."); - this.customArgumentResolvers = customArgumentResolvers; + this.customArgumentResolvers.clear(); + if (customArgumentResolvers != null) { + this.customArgumentResolvers.addAll(customArgumentResolvers); + } } + /** + * Return the configured custom argument resolvers, if any. + */ public List getCustomArgumentResolvers() { return this.customArgumentResolvers; } @@ -135,10 +140,15 @@ public abstract class AbstractMethodMessageHandler * @param customReturnValueHandlers the list of custom return value handlers, never {@code null}. */ public void setCustomReturnValueHandlers(List customReturnValueHandlers) { - Assert.notNull(customReturnValueHandlers, "The 'customReturnValueHandlers' cannot be null."); - this.customReturnValueHandlers = customReturnValueHandlers; + this.customReturnValueHandlers.clear(); + if (customReturnValueHandlers != null) { + this.customReturnValueHandlers.addAll(customReturnValueHandlers); + } } + /** + * Return the configured custom return value handlers, if any. + */ public List getCustomReturnValueHandlers() { return this.customReturnValueHandlers; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java index a3914a1818f..be198b1b839 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java @@ -27,6 +27,8 @@ import org.springframework.context.ApplicationContextAware; import org.springframework.context.annotation.Bean; import org.springframework.messaging.Message; import org.springframework.messaging.converter.*; +import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler; @@ -40,6 +42,7 @@ import org.springframework.messaging.support.AbstractSubscribableChannel; import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; import org.springframework.util.MimeTypeUtils; import org.springframework.util.PathMatcher; import org.springframework.validation.Errors; @@ -213,6 +216,14 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC handler.setMessageConverter(brokerMessageConverter()); handler.setValidator(simpValidator()); + List argumentResolvers = new ArrayList(); + addArgumentResolvers(argumentResolvers); + handler.setCustomArgumentResolvers(argumentResolvers); + + List returnValueHandlers = new ArrayList(); + addReturnValueHandlers(returnValueHandlers); + handler.setCustomReturnValueHandlers(returnValueHandlers); + PathMatcher pathMatcher = this.getBrokerRegistry().getPathMatcher(); if (pathMatcher != null) { handler.setPathMatcher(pathMatcher); @@ -220,6 +231,12 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC return handler; } + protected void addArgumentResolvers(List argumentResolvers) { + } + + protected void addReturnValueHandlers(List returnValueHandlers) { + } + @Bean public AbstractBrokerMessageHandler simpleBrokerMessageHandler() { SimpleBrokerMessageHandler handler = getBrokerRegistry().getSimpleBroker(brokerChannel()); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java index 73e4fd362a6..d2f596cc497 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java @@ -26,8 +26,6 @@ import org.hamcrest.Matchers; import org.junit.Before; import org.junit.Test; -import org.mockito.Mockito; - import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -37,6 +35,8 @@ import org.springframework.messaging.MessageHandler; import org.springframework.messaging.converter.*; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.SendTo; +import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SubscribeMapping; import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; @@ -61,6 +61,7 @@ import org.springframework.validation.Validator; import org.springframework.validation.beanvalidation.OptionalValidatorFactoryBean; import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; /** * Test fixture for {@link AbstractMessageBrokerConfiguration}. @@ -77,9 +78,7 @@ public class MessageBrokerConfigurationTests { private AnnotationConfigApplicationContext defaultContext; - private AnnotationConfigApplicationContext customChannelContext; - - private AnnotationConfigApplicationContext customPathMatcherContext; + private AnnotationConfigApplicationContext customContext; @Before @@ -97,13 +96,9 @@ public class MessageBrokerConfigurationTests { this.defaultContext.register(DefaultConfig.class); this.defaultContext.refresh(); - this.customChannelContext = new AnnotationConfigApplicationContext(); - this.customChannelContext.register(CustomChannelConfig.class); - this.customChannelContext.refresh(); - - this.customPathMatcherContext = new AnnotationConfigApplicationContext(); - this.customPathMatcherContext.register(CustomPathMatcherConfig.class); - this.customPathMatcherContext.refresh(); + this.customContext = new AnnotationConfigApplicationContext(); + this.customContext.register(CustomConfig.class); + this.customContext.refresh(); } @@ -132,12 +127,12 @@ public class MessageBrokerConfigurationTests { @Test public void clientInboundChannelCustomized() { - AbstractSubscribableChannel channel = this.customChannelContext.getBean( + AbstractSubscribableChannel channel = this.customContext.getBean( "clientInboundChannel", AbstractSubscribableChannel.class); assertEquals(1, channel.getInterceptors().size()); - ThreadPoolTaskExecutor taskExecutor = this.customChannelContext.getBean( + ThreadPoolTaskExecutor taskExecutor = this.customContext.getBean( "clientInboundChannelExecutor", ThreadPoolTaskExecutor.class); assertEquals(11, taskExecutor.getCorePoolSize()); @@ -200,12 +195,12 @@ public class MessageBrokerConfigurationTests { @Test public void clientOutboundChannelCustomized() { - AbstractSubscribableChannel channel = this.customChannelContext.getBean( + AbstractSubscribableChannel channel = this.customContext.getBean( "clientOutboundChannel", AbstractSubscribableChannel.class); assertEquals(2, channel.getInterceptors().size()); - ThreadPoolTaskExecutor taskExecutor = this.customChannelContext.getBean( + ThreadPoolTaskExecutor taskExecutor = this.customContext.getBean( "clientOutboundChannelExecutor", ThreadPoolTaskExecutor.class); assertEquals(21, taskExecutor.getCorePoolSize()); @@ -280,12 +275,12 @@ public class MessageBrokerConfigurationTests { @Test public void brokerChannelCustomized() { - AbstractSubscribableChannel channel = this.customChannelContext.getBean( + AbstractSubscribableChannel channel = this.customContext.getBean( "brokerChannel", AbstractSubscribableChannel.class); assertEquals(3, channel.getInterceptors().size()); - ThreadPoolTaskExecutor taskExecutor = this.customChannelContext.getBean( + ThreadPoolTaskExecutor taskExecutor = this.customContext.getBean( "brokerChannelExecutor", ThreadPoolTaskExecutor.class); assertEquals(31, taskExecutor.getCorePoolSize()); @@ -328,7 +323,7 @@ public class MessageBrokerConfigurationTests { @Test public void configureMessageConvertersCustom() { - final MessageConverter testConverter = Mockito.mock(MessageConverter.class); + final MessageConverter testConverter = mock(MessageConverter.class); AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() { @Override protected boolean configureMessageConverters(List messageConverters) { @@ -346,7 +341,7 @@ public class MessageBrokerConfigurationTests { @Test public void configureMessageConvertersCustomAndDefault() { - final MessageConverter testConverter = Mockito.mock(MessageConverter.class); + final MessageConverter testConverter = mock(MessageConverter.class); AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() { @Override @@ -365,6 +360,19 @@ public class MessageBrokerConfigurationTests { assertThat(iterator.next(), Matchers.instanceOf(MappingJackson2MessageConverter.class)); } + @Test + public void customArgumentAndReturnValueTypes() throws Exception { + SimpAnnotationMethodMessageHandler handler = this.customContext.getBean(SimpAnnotationMethodMessageHandler.class); + + List customResolvers = handler.getCustomArgumentResolvers(); + assertEquals(1, customResolvers.size()); + assertTrue(handler.getArgumentResolvers().contains(customResolvers.get(0))); + + List customHandlers = handler.getCustomReturnValueHandlers(); + assertEquals(1, customHandlers.size()); + assertTrue(handler.getReturnValueHandlers().contains(customHandlers.get(0))); + } + @Test public void simpValidatorDefault() { AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() {}; @@ -376,7 +384,7 @@ public class MessageBrokerConfigurationTests { @Test public void simpValidatorCustom() { - final Validator validator = Mockito.mock(Validator.class); + final Validator validator = mock(Validator.class); AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() { @Override public Validator getValidator() { @@ -408,11 +416,11 @@ public class MessageBrokerConfigurationTests { @Test public void customPathMatcher() { - SimpleBrokerMessageHandler broker = this.customPathMatcherContext.getBean(SimpleBrokerMessageHandler.class); + SimpleBrokerMessageHandler broker = this.customContext.getBean(SimpleBrokerMessageHandler.class); DefaultSubscriptionRegistry registry = (DefaultSubscriptionRegistry) broker.getSubscriptionRegistry(); assertEquals("a.a", registry.getPathMatcher().combine("a", "a")); - SimpAnnotationMethodMessageHandler handler = this.customPathMatcherContext.getBean(SimpAnnotationMethodMessageHandler.class); + SimpAnnotationMethodMessageHandler handler = this.customContext.getBean(SimpAnnotationMethodMessageHandler.class); assertEquals("a.a", handler.getPathMatcher().combine("a", "a")); } @@ -474,7 +482,7 @@ public class MessageBrokerConfigurationTests { } @Configuration - static class CustomChannelConfig extends AbstractMessageBrokerConfiguration { + static class CustomConfig extends AbstractMessageBrokerConfiguration { private ChannelInterceptor interceptor = new ChannelInterceptorAdapter() {}; @@ -491,19 +499,19 @@ public class MessageBrokerConfigurationTests { } @Override - protected void configureMessageBroker(MessageBrokerRegistry registry) { - registry.configureBrokerChannel().setInterceptors( - this.interceptor, this.interceptor, this.interceptor); - registry.configureBrokerChannel().taskExecutor() - .corePoolSize(31).maxPoolSize(32).keepAliveSeconds(33).queueCapacity(34); + protected void addArgumentResolvers(List argumentResolvers) { + argumentResolvers.add(mock(HandlerMethodArgumentResolver.class)); } - } - @Configuration - static class CustomPathMatcherConfig extends SimpleBrokerConfig { + @Override + protected void addReturnValueHandlers(List returnValueHandlers) { + returnValueHandlers.add(mock(HandlerMethodReturnValueHandler.class)); + } @Override - public void configureMessageBroker(MessageBrokerRegistry registry) { + protected void configureMessageBroker(MessageBrokerRegistry registry) { + registry.configureBrokerChannel().setInterceptors(this.interceptor, this.interceptor, this.interceptor); + registry.configureBrokerChannel().taskExecutor().corePoolSize(31).maxPoolSize(32).keepAliveSeconds(33).queueCapacity(34); registry.setPathMatcher(new AntPathMatcher(".")).enableSimpleBroker("/topic", "/queue"); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java index 0b9c04da22c..58e7d945e7f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java @@ -401,9 +401,30 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { String pathMatcherRef = messageBrokerElement.getAttribute("path-matcher"); beanDef.getPropertyValues().add("pathMatcher", new RuntimeBeanReference(pathMatcherRef)); } + + Element resolversElement = DomUtils.getChildElementByTagName(messageBrokerElement, "argument-resolvers"); + if (resolversElement != null) { + values.add("customArgumentResolvers", extractBeanSubElements(resolversElement, context)); + } + + Element handlersElement = DomUtils.getChildElementByTagName(messageBrokerElement, "return-value-handlers"); + if (handlersElement != null) { + values.add("customReturnValueHandlers", extractBeanSubElements(handlersElement, context)); + } + registerBeanDef(beanDef, context, source); } + private ManagedList extractBeanSubElements(Element parentElement, ParserContext parserContext) { + ManagedList list = new ManagedList(); + list.setSource(parserContext.extractSource(parentElement)); + for (Element beanElement : DomUtils.getChildElementsByTagName(parentElement, "bean", "ref")) { + Object object = parserContext.getDelegate().parsePropertySubElement(beanElement, null); + list.add(object); + } + return list; + } + private RuntimeBeanReference registerUserDestinationResolver(Element brokerElem, RuntimeBeanReference userSessionRegistry, ParserContext context, Object source) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketMessageBrokerConfigurer.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketMessageBrokerConfigurer.java index f3b0e101c44..f744e468206 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketMessageBrokerConfigurer.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketMessageBrokerConfigurer.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.config.annotation; import org.springframework.messaging.converter.MessageConverter; +import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; @@ -49,6 +51,14 @@ public abstract class AbstractWebSocketMessageBrokerConfigurer implements WebSoc return true; } + @Override + public void addArgumentResolvers(List argumentResolvers) { + } + + @Override + public void addReturnValueHandlers(List returnValueHandlers) { + } + @Override public void configureMessageBroker(MessageBrokerRegistry registry) { } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/DelegatingWebSocketMessageBrokerConfiguration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/DelegatingWebSocketMessageBrokerConfiguration.java index c7917a3008e..53ad7edd55d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/DelegatingWebSocketMessageBrokerConfiguration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/DelegatingWebSocketMessageBrokerConfiguration.java @@ -22,6 +22,8 @@ import java.util.List; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.converter.MessageConverter; +import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.util.CollectionUtils; @@ -79,6 +81,20 @@ public class DelegatingWebSocketMessageBrokerConfiguration extends WebSocketMess } } + @Override + protected void addArgumentResolvers(List argumentResolvers) { + for (WebSocketMessageBrokerConfigurer c : this.configurers) { + c.addArgumentResolvers(argumentResolvers); + } + } + + @Override + protected void addReturnValueHandlers(List returnValueHandlers) { + for (WebSocketMessageBrokerConfigurer c : this.configurers) { + c.addReturnValueHandlers(returnValueHandlers); + } + } + @Override protected boolean configureMessageConverters(List messageConverters) { boolean registerDefaults = true; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurer.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurer.java index 74c41d128ec..786d2dfdcb1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurer.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurer.java @@ -18,6 +18,8 @@ package org.springframework.web.socket.config.annotation; import org.springframework.messaging.converter.MessageConverter; +import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; @@ -62,6 +64,26 @@ public interface WebSocketMessageBrokerConfigurer { */ void configureClientOutboundChannel(ChannelRegistration registration); + /** + * Add resolvers to support custom controller method argument types. + *

This does not override the built-in support for resolving handler + * method arguments. To customize the built-in support for argument + * resolution, configure {@code SimpAnnotationMethodMessageHandler} directly. + * @param argumentResolvers initially an empty list + * @since 4.1.1 + */ + void addArgumentResolvers(List argumentResolvers); + + /** + * Add handlers to support custom controller method return value types. + *

Using this option does not override the built-in support for handling + * return values. To customize the built-in support for handling return + * values, configure {@code SimpAnnotationMethodMessageHandler} directly. + * @param returnValueHandlers initially an empty list + * @since 4.1.1 + */ + void addReturnValueHandlers(List returnValueHandlers); + /** * Configure the message converters to use when extracting the payload of * messages in annotated methods and when sending messages (e.g. through the diff --git a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd index ea3dfa0fe86..59576246271 100644 --- a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd +++ b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd @@ -615,6 +615,70 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + customResolvers = handler.getCustomArgumentResolvers(); + assertEquals(2, customResolvers.size()); + assertTrue(handler.getArgumentResolvers().contains(customResolvers.get(0))); + assertTrue(handler.getArgumentResolvers().contains(customResolvers.get(1))); + + List customHandlers = handler.getCustomReturnValueHandlers(); + assertEquals(2, customHandlers.size()); + assertTrue(handler.getReturnValueHandlers().contains(customHandlers.get(0))); + assertTrue(handler.getReturnValueHandlers().contains(customHandlers.get(1))); + } + @Test public void messageConverters() { loadBeanDefinitions("websocket-config-broker-converters.xml"); @@ -396,3 +417,30 @@ public class MessageBrokerBeanDefinitionParserTests { } } + +class CustomArgumentResolver implements HandlerMethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return false; + } + + @Override + public Object resolveArgument(MethodParameter parameter, Message message) throws Exception { + return null; + } + +} + +class CustomReturnValueHandler implements HandlerMethodReturnValueHandler { + + @Override + public boolean supportsReturnType(MethodParameter returnType) { + return false; + } + + @Override + public void handleReturnValue(Object returnValue, MethodParameter returnType, Message message) throws Exception { + + } +} \ No newline at end of file diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-custom-argument-and-return-value-types.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-custom-argument-and-return-value-types.xml new file mode 100644 index 00000000000..0fede774e83 --- /dev/null +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-custom-argument-and-return-value-types.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + + + + + + + + + + + + + + + +