diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultControllerSpec.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultControllerSpec.java index fe46da28205..9d61f4b346d 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultControllerSpec.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultControllerSpec.java @@ -15,6 +15,7 @@ */ package org.springframework.test.web.reactive.server; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -44,6 +45,8 @@ class DefaultControllerSpec implements WebTestClient.ControllerSpec { private final List controllers; + private final List controllerAdvice = new ArrayList<>(8); + private final TestWebFluxConfigurer configurer = new TestWebFluxConfigurer(); @@ -53,6 +56,12 @@ class DefaultControllerSpec implements WebTestClient.ControllerSpec { } + @Override + public DefaultControllerSpec controllerAdvice(Object... controllerAdvice) { + this.controllerAdvice.addAll(Arrays.asList(controllerAdvice)); + return this; + } + @Override public DefaultControllerSpec contentTypeResolver(Consumer consumer) { this.configurer.contentTypeResolverConsumer = consumer; @@ -103,12 +112,17 @@ class DefaultControllerSpec implements WebTestClient.ControllerSpec { @Override public WebTestClient.Builder configureClient() { + return WebTestClient.bindToApplicationContext(createApplicationContext()); + } + + protected AnnotationConfigApplicationContext createApplicationContext() { AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); this.controllers.forEach(controller -> registerBean(context, controller)); + this.controllerAdvice.forEach(advice -> registerBean(context, advice)); context.register(DelegatingWebFluxConfiguration.class); context.registerBean(WebFluxConfigurer.class, () -> this.configurer); context.refresh(); - return WebTestClient.bindToApplicationContext(context); + return context; } @SuppressWarnings("unchecked") diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/WebTestClient.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/WebTestClient.java index a3538fa2abd..c0c1a04423b 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/WebTestClient.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/WebTestClient.java @@ -179,6 +179,13 @@ public interface WebTestClient { */ interface ControllerSpec { + /** + * Register one or more + * {@link org.springframework.web.bind.annotation.ControllerAdvice + * ControllerAdvice} instances to be used in tests. + */ + ControllerSpec controllerAdvice(Object... controllerAdvice); + /** * Customize content type resolution. * @see WebFluxConfigurer#configureContentTypeResolver diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/DefaultControllerSpecTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/DefaultControllerSpecTests.java new file mode 100644 index 00000000000..70879dcd2d9 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/DefaultControllerSpecTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.reactive.server; + +import org.junit.Test; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; + +import static org.junit.Assert.assertSame; + +/** + * Unit tests for {@link DefaultControllerSpec}. + * @author Rossen Stoyanchev + */ +public class DefaultControllerSpecTests { + + @Test + public void controllers() throws Exception { + OneController controller1 = new OneController(); + SecondController controller2 = new SecondController(); + + TestControllerSpec spec = new TestControllerSpec(controller1, controller2); + ApplicationContext context = spec.createApplicationContext(); + + assertSame(controller1, context.getBean(OneController.class)); + assertSame(controller2, context.getBean(SecondController.class)); + } + + @Test + public void controllerAdvice() throws Exception { + OneControllerAdvice advice = new OneControllerAdvice(); + + TestControllerSpec spec = new TestControllerSpec(new OneController()); + spec.controllerAdvice(advice); + ApplicationContext context = spec.createApplicationContext(); + + assertSame(advice, context.getBean(OneControllerAdvice.class)); + } + + private static class OneController {} + + private static class SecondController {} + + private static class OneControllerAdvice {} + + + private static class TestControllerSpec extends DefaultControllerSpec { + + TestControllerSpec(Object... controllers) { + super(controllers); + } + + @Override + public AnnotationConfigApplicationContext createApplicationContext() { + return super.createApplicationContext(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/ExceptionHandlerMethodResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/ExceptionHandlerMethodResolver.java index b56b3a4ddde..84558dde7f9 100644 --- a/spring-web/src/main/java/org/springframework/web/method/annotation/ExceptionHandlerMethodResolver.java +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/ExceptionHandlerMethodResolver.java @@ -125,6 +125,17 @@ public class ExceptionHandlerMethodResolver { * @return a Method to handle the exception, or {@code null} if none found */ public Method resolveMethod(Exception exception) { + return resolveMethod(exception); + } + + /** + * Find a {@link Method} to handle the given Throwable. + * Use {@link ExceptionDepthComparator} if more than one match is found. + * @param exception the exception + * @return a Method to handle the exception, or {@code null} if none found + * @since 5.0 + */ + public Method resolveMethodByThrowable(Throwable exception) { Method method = resolveMethodByExceptionType(exception.getClass()); if (method == null) { Throwable cause = exception.getCause(); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java index d8bac642ff6..9541f910320 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java @@ -18,24 +18,26 @@ package org.springframework.web.reactive.result.method.annotation; import java.lang.reflect.Method; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; -import java.util.stream.Collectors; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; -import org.springframework.beans.BeansException; -import org.springframework.beans.factory.BeanFactory; -import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.config.ConfigurableBeanFactory; -import org.springframework.core.MethodIntrospector; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.context.ConfigurableApplicationContext; import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.codec.ByteArrayDecoder; import org.springframework.core.codec.ByteBufferDecoder; @@ -43,11 +45,13 @@ import org.springframework.core.codec.DataBufferDecoder; import org.springframework.core.codec.StringDecoder; import org.springframework.http.codec.DecoderHttpMessageReader; import org.springframework.http.codec.HttpMessageReader; +import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.support.WebBindingInitializer; +import org.springframework.web.method.ControllerAdviceBean; import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.annotation.ExceptionHandlerMethodResolver; import org.springframework.web.reactive.BindingContext; @@ -59,13 +63,15 @@ import org.springframework.web.reactive.result.method.SyncHandlerMethodArgumentR import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod; import org.springframework.web.server.ServerWebExchange; +import static org.springframework.core.MethodIntrospector.selectMethods; + /** * Supports the invocation of {@code @RequestMapping} methods. * * @author Rossen Stoyanchev * @since 5.0 */ -public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactoryAware, InitializingBean { +public class RequestMappingHandlerAdapter implements HandlerAdapter, ApplicationContextAware, InitializingBean { private static final Log logger = LogFactory.getLog(RequestMappingHandlerAdapter.class); @@ -84,10 +90,8 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory private List initBinderArgumentResolvers; - private ConfigurableBeanFactory beanFactory; - + private ConfigurableApplicationContext applicationContext; - private ModelInitializer modelInitializer; private final Map, Set> binderMethodCache = new ConcurrentHashMap<>(64); @@ -97,6 +101,18 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory new ConcurrentHashMap<>(64); + private final Map> binderAdviceCache = new LinkedHashMap<>(64); + + private final Map> attributeAdviceCache = new LinkedHashMap<>(64); + + private final Map exceptionHandlerAdviceCache = + new LinkedHashMap<>(64); + + + private ModelInitializer modelInitializer; + + + public RequestMappingHandlerAdapter() { this.messageReaders.add(new DecoderHttpMessageReader<>(new ByteArrayDecoder())); this.messageReaders.add(new DecoderHttpMessageReader<>(new ByteBufferDecoder())); @@ -204,25 +220,30 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory return this.initBinderArgumentResolvers; } - /** - * A {@link ConfigurableBeanFactory} is expected for resolving expressions - * in method argument default values. + * A {@link ConfigurableApplicationContext} is expected for resolving + * expressions in method argument default values as well as for + * detecting {@code @ControllerAdvice} beans. */ @Override - public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - if (beanFactory instanceof ConfigurableBeanFactory) { - this.beanFactory = (ConfigurableBeanFactory) beanFactory; + public void setApplicationContext(ApplicationContext applicationContext) { + if (applicationContext instanceof ConfigurableApplicationContext) { + this.applicationContext = (ConfigurableApplicationContext) applicationContext; } } + public ConfigurableApplicationContext getApplicationContext() { + return this.applicationContext; + } + public ConfigurableBeanFactory getBeanFactory() { - return this.beanFactory; + return this.applicationContext.getBeanFactory(); } @Override public void afterPropertiesSet() throws Exception { + initControllerAdviceCache(); if (this.argumentResolvers == null) { this.argumentResolvers = getDefaultArgumentResolvers(); } @@ -232,6 +253,43 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory this.modelInitializer = new ModelInitializer(getReactiveAdapterRegistry()); } + private void initControllerAdviceCache() { + if (getApplicationContext() == null) { + return; + } + if (logger.isInfoEnabled()) { + logger.info("Looking for @ControllerAdvice: " + getApplicationContext()); + } + + List beans = ControllerAdviceBean.findAnnotatedBeans(getApplicationContext()); + AnnotationAwareOrderComparator.sort(beans); + + for (ControllerAdviceBean bean : beans) { + Class beanType = bean.getBeanType(); + Set attrMethods = selectMethods(beanType, ATTRIBUTE_METHODS); + if (!attrMethods.isEmpty()) { + this.attributeAdviceCache.put(bean, attrMethods); + if (logger.isInfoEnabled()) { + logger.info("Detected @ModelAttribute methods in " + bean); + } + } + Set binderMethods = selectMethods(beanType, BINDER_METHODS); + if (!binderMethods.isEmpty()) { + this.binderAdviceCache.put(bean, binderMethods); + if (logger.isInfoEnabled()) { + logger.info("Detected @InitBinder methods in " + bean); + } + } + ExceptionHandlerMethodResolver resolver = new ExceptionHandlerMethodResolver(beanType); + if (resolver.hasExceptionMappings()) { + this.exceptionHandlerAdviceCache.put(bean, resolver); + if (logger.isInfoEnabled()) { + logger.info("Detected @ExceptionHandler methods in " + bean); + } + } + } + } + protected List getDefaultArgumentResolvers() { List resolvers = new ArrayList<>(); @@ -305,80 +363,97 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory @Override public Mono handle(ServerWebExchange exchange, Object handler) { + Assert.notNull(handler, "Expected handler"); HandlerMethod handlerMethod = (HandlerMethod) handler; BindingContext bindingContext = new InitBinderBindingContext( getWebBindingInitializer(), getBinderMethods(handlerMethod)); - Mono modelCompletion = this.modelInitializer.initModel( - bindingContext, getAttributeMethods(handlerMethod), exchange); - - Function> exceptionHandler = - ex -> handleException(ex, handlerMethod, bindingContext, exchange); - - return modelCompletion.then(() -> { + return this.modelInitializer + .initModel(bindingContext, getAttributeMethods(handlerMethod), exchange) + .then(() -> { + Function> exceptionHandler = + ex -> handleException(exchange, handlerMethod, bindingContext, ex); - InvocableHandlerMethod invocable = new InvocableHandlerMethod(handlerMethod); - invocable.setArgumentResolvers(getArgumentResolvers()); + InvocableHandlerMethod invocable = new InvocableHandlerMethod(handlerMethod); + invocable.setArgumentResolvers(getArgumentResolvers()); - return invocable.invoke(exchange, bindingContext) - .doOnNext(result -> result.setExceptionHandler(exceptionHandler)) - .otherwise(exceptionHandler); - }); + return invocable.invoke(exchange, bindingContext) + .doOnNext(result -> result.setExceptionHandler(exceptionHandler)) + .otherwise(exceptionHandler); + }); } private List getBinderMethods(HandlerMethod handlerMethod) { + List result = new ArrayList<>(); Class handlerType = handlerMethod.getBeanType(); - Set methods = this.binderMethodCache.computeIfAbsent(handlerType, aClass -> - MethodIntrospector.selectMethods(handlerType, BINDER_METHODS)); + // Global methods first + this.binderAdviceCache.entrySet().forEach(entry -> { + if (entry.getKey().isApplicableToBeanType(handlerType)) { + Object bean = entry.getKey().resolveBean(); + entry.getValue().forEach(method -> result.add(createBinderMethod(bean, method))); + } + }); - return methods.stream() - .map(method -> { + this.binderMethodCache + .computeIfAbsent(handlerType, aClass -> selectMethods(handlerType, BINDER_METHODS)) + .forEach(method -> { Object bean = handlerMethod.getBean(); - SyncInvocableHandlerMethod invocable = new SyncInvocableHandlerMethod(bean, method); - invocable.setSyncArgumentResolvers(getInitBinderArgumentResolvers()); - return invocable; - }) - .collect(Collectors.toList()); + result.add(createBinderMethod(bean, method)); + }); + + return result; + } + + private SyncInvocableHandlerMethod createBinderMethod(Object bean, Method method) { + SyncInvocableHandlerMethod invocable = new SyncInvocableHandlerMethod(bean, method); + invocable.setSyncArgumentResolvers(getInitBinderArgumentResolvers()); + return invocable; } private List getAttributeMethods(HandlerMethod handlerMethod) { + List result = new ArrayList<>(); Class handlerType = handlerMethod.getBeanType(); - Set methods = this.attributeMethodCache.computeIfAbsent(handlerType, aClass -> - MethodIntrospector.selectMethods(handlerType, ATTRIBUTE_METHODS)); + // Global methods first + this.attributeAdviceCache.entrySet().forEach(entry -> { + if (entry.getKey().isApplicableToBeanType(handlerType)) { + Object bean = entry.getKey().resolveBean(); + entry.getValue().forEach(method -> result.add(createHandlerMethod(bean, method))); + } + }); - return methods.stream() - .map(method -> { + this.attributeMethodCache + .computeIfAbsent(handlerType, aClass -> selectMethods(handlerType, ATTRIBUTE_METHODS)) + .forEach(method -> { Object bean = handlerMethod.getBean(); - InvocableHandlerMethod invocable = new InvocableHandlerMethod(bean, method); - invocable.setArgumentResolvers(getArgumentResolvers()); - return invocable; - }) - .collect(Collectors.toList()); - } + result.add(createHandlerMethod(bean, method)); + }); - private Mono handleException(Throwable ex, HandlerMethod handlerMethod, - BindingContext bindingContext, ServerWebExchange exchange) { + return result; + } - ExceptionHandlerMethodResolver resolver = this.exceptionHandlerCache - .computeIfAbsent(handlerMethod.getBeanType(), ExceptionHandlerMethodResolver::new); + private InvocableHandlerMethod createHandlerMethod(Object bean, Method method) { + InvocableHandlerMethod invocable = new InvocableHandlerMethod(bean, method); + invocable.setArgumentResolvers(getArgumentResolvers()); + return invocable; + } - Method method = resolver.resolveMethodByExceptionType(ex.getClass()); + private Mono handleException(ServerWebExchange exchange, HandlerMethod handlerMethod, + BindingContext bindingContext, Throwable ex) { - if (method != null) { - Object bean = handlerMethod.getBean(); - InvocableHandlerMethod invocable = new InvocableHandlerMethod(bean, method); - invocable.setArgumentResolvers(getArgumentResolvers()); + InvocableHandlerMethod invocable = getExceptionHandlerMethod(ex, handlerMethod); + if (invocable != null) { try { if (logger.isDebugEnabled()) { logger.debug("Invoking @ExceptionHandler method: " + invocable.getMethod()); } bindingContext.getModel().asMap().clear(); - return invocable.invoke(exchange, bindingContext, ex); + Throwable cause = ex.getCause() != null ? ex.getCause() : ex; + return invocable.invoke(exchange, bindingContext, cause, handlerMethod); } catch (Throwable invocationEx) { if (logger.isWarnEnabled()) { @@ -386,10 +461,36 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, BeanFactory } } } - return Mono.error(ex); } + private InvocableHandlerMethod getExceptionHandlerMethod(Throwable ex, HandlerMethod handlerMethod) { + + Class handlerType = handlerMethod.getBeanType(); + + ExceptionHandlerMethodResolver resolver = this.exceptionHandlerCache + .computeIfAbsent(handlerType, ExceptionHandlerMethodResolver::new); + + return Optional + .ofNullable(resolver.resolveMethodByThrowable(ex)) + .map(method -> createHandlerMethod(handlerMethod.getBean(), method)) + .orElseGet(() -> + this.exceptionHandlerAdviceCache.entrySet().stream() + .map(entry -> { + if (entry.getKey().isApplicableToBeanType(handlerType)) { + Method method = entry.getValue().resolveMethodByThrowable(ex); + if (method != null) { + Object bean = entry.getKey().resolveBean(); + return createHandlerMethod(bean, method); + } + } + return null; + }) + .filter(Objects::nonNull) + .findFirst() + .orElse(null)); + } + /** * MethodFilter that matches {@link InitBinder @InitBinder} methods. diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerAdviceTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerAdviceTests.java new file mode 100644 index 00000000000..31f41a697d7 --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerAdviceTests.java @@ -0,0 +1,246 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.reactive.result.method.annotation; + +import java.lang.reflect.Method; +import java.time.Duration; +import java.util.Collections; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.FatalBeanException; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.annotation.Order; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.stereotype.Controller; +import org.springframework.ui.Model; +import org.springframework.util.ClassUtils; +import org.springframework.validation.Validator; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.bind.annotation.ExceptionHandler; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.InitBinder; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.support.WebExchangeDataBinder; +import org.springframework.web.method.HandlerMethod; +import org.springframework.web.reactive.BindingContext; +import org.springframework.web.reactive.HandlerResult; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.adapter.DefaultServerWebExchange; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; + +/** + * {@code @ControllerAdvice} related tests for {@link RequestMappingHandlerAdapter}. + * @author Rossen Stoyanchev + */ +public class ControllerAdviceTests { + + private ServerWebExchange exchange; + + + @Before + public void setUp() throws Exception { + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerHttpResponse response = new MockServerHttpResponse(); + this.exchange = new DefaultServerWebExchange(request, response); + } + + + @Test + public void resolveExceptionGlobalHandler() throws Exception { + testException(new IllegalAccessException(), "SecondControllerAdvice: IllegalAccessException"); + } + + @Test + public void resolveExceptionGlobalHandlerOrdered() throws Exception { + testException(new IllegalStateException(), "OneControllerAdvice: IllegalStateException"); + } + + @Test // SPR-12605 + public void resolveExceptionWithHandlerMethodArg() throws Exception { + testException(new ArrayIndexOutOfBoundsException(), "HandlerMethod: handle"); + } + + @Test + public void resolveExceptionWithAssertionError() throws Exception { + AssertionError error = new AssertionError("argh"); + testException(error, error.toString()); + } + + @Test + public void resolveExceptionWithAssertionErrorAsRootCause() throws Exception { + AssertionError cause = new AssertionError("argh"); + FatalBeanException exception = new FatalBeanException("wrapped", cause); + testException(exception, cause.toString()); + } + + private void testException(Throwable exception, String expected) throws Exception { + ApplicationContext context = new AnnotationConfigApplicationContext(TestConfig.class); + RequestMappingHandlerAdapter adapter = createAdapter(context); + + TestController controller = context.getBean(TestController.class); + controller.setException(exception); + + Object actual = handle(adapter, controller, "handle").getReturnValue().orElse(null); + assertEquals(expected, actual); + } + + @Test + public void modelAttributeAdvice() throws Exception { + ApplicationContext context = new AnnotationConfigApplicationContext(TestConfig.class); + RequestMappingHandlerAdapter adapter = createAdapter(context); + TestController controller = context.getBean(TestController.class); + + Model model = handle(adapter, controller, "handle").getModel(); + + assertEquals(2, model.asMap().size()); + assertEquals("lAttr1", model.asMap().get("attr1")); + assertEquals("gAttr2", model.asMap().get("attr2")); + } + + @Test + public void initBinderAdvice() throws Exception { + ApplicationContext context = new AnnotationConfigApplicationContext(TestConfig.class); + RequestMappingHandlerAdapter adapter = createAdapter(context); + TestController controller = context.getBean(TestController.class); + + Validator validator = mock(Validator.class); + controller.setValidator(validator); + + BindingContext bindingContext = handle(adapter, controller, "handle").getBindingContext(); + + WebExchangeDataBinder binder = bindingContext.createDataBinder(this.exchange, "name"); + assertEquals(Collections.singletonList(validator), binder.getValidators()); + } + + + private RequestMappingHandlerAdapter createAdapter(ApplicationContext context) throws Exception { + RequestMappingHandlerAdapter adapter = new RequestMappingHandlerAdapter(); + adapter.setApplicationContext(context); + adapter.afterPropertiesSet(); + return adapter; + } + + private HandlerResult handle(RequestMappingHandlerAdapter adapter, + Object controller, String methodName) throws Exception { + + Method method = controller.getClass().getMethod(methodName); + HandlerMethod handlerMethod = new HandlerMethod(controller, method); + return adapter.handle(this.exchange, handlerMethod).block(Duration.ZERO); + } + + + @Configuration + static class TestConfig { + + @Bean + public TestController testController() { + return new TestController(); + } + + @Bean + public OneControllerAdvice testExceptionResolver() { + return new OneControllerAdvice(); + } + + @Bean + public SecondControllerAdvice anotherTestExceptionResolver() { + return new SecondControllerAdvice(); + } + } + + @Controller + static class TestController { + + private Validator validator; + + private Throwable exception; + + + void setValidator(Validator validator) { + this.validator = validator; + } + + void setException(Throwable exception) { + this.exception = exception; + } + + + @InitBinder + public void initDataBinder(WebDataBinder dataBinder) { + if (this.validator != null) { + dataBinder.addValidators(this.validator); + } + } + + @ModelAttribute + public void addAttributes(Model model) { + model.addAttribute("attr1", "lAttr1"); + } + + @GetMapping + public void handle() throws Throwable { + if (this.exception != null) { + throw this.exception; + } + } + } + + @ControllerAdvice + @Order(1) + static class OneControllerAdvice { + + @ModelAttribute + public void addAttributes(Model model) { + model.addAttribute("attr1", "gAttr1"); + model.addAttribute("attr2", "gAttr2"); + } + + @ExceptionHandler + public String handleException(IllegalStateException ex) { + return "OneControllerAdvice: " + ClassUtils.getShortName(ex.getClass()); + } + + @ExceptionHandler(ArrayIndexOutOfBoundsException.class) + public String handleWithHandlerMethod(HandlerMethod handlerMethod) { + return "HandlerMethod: " + handlerMethod.getMethod().getName(); + } + + @ExceptionHandler(AssertionError.class) + public String handleAssertionError(Error err) { + return err.toString(); + } + } + + @ControllerAdvice + @Order(2) + static class SecondControllerAdvice { + + @ExceptionHandler({IllegalStateException.class, IllegalAccessException.class}) + public String handleException(Exception ex) { + return "SecondControllerAdvice: " + ClassUtils.getShortName(ex.getClass()); + } + } + +} diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ModelInitializerTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ModelInitializerTests.java index 4b10e248d54..6e7ab164b12 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ModelInitializerTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ModelInitializerTests.java @@ -31,7 +31,6 @@ import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.ui.Model; -import org.springframework.util.ObjectUtils; import org.springframework.validation.Validator; import org.springframework.web.bind.WebDataBinder; import org.springframework.web.bind.annotation.InitBinder; @@ -48,8 +47,8 @@ import org.springframework.web.server.adapter.DefaultServerWebExchange; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; -import static org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter.BINDER_METHODS; import static org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter.ATTRIBUTE_METHODS; +import static org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter.BINDER_METHODS; /** * Unit tests for {@link ModelInitializer}. @@ -76,8 +75,10 @@ public class ModelInitializerTests { @SuppressWarnings("unchecked") @Test public void basic() throws Exception { + TestController controller = new TestController(); + Validator validator = mock(Validator.class); - Object controller = new TestController(validator); + controller.setValidator(validator); List binderMethods = getBinderMethods(controller); List attributeMethods = getAttributeMethods(controller); @@ -131,16 +132,18 @@ public class ModelInitializerTests { @SuppressWarnings("unused") private static class TestController { - private Validator[] validators; + private Validator validator; - public TestController(Validator... validators) { - this.validators = validators; + + void setValidator(Validator validator) { + this.validator = validator; } + @InitBinder public void initDataBinder(WebDataBinder dataBinder) { - if (!ObjectUtils.isEmpty(this.validators)) { - dataBinder.addValidators(this.validators); + if (this.validator != null) { + dataBinder.addValidators(this.validator); } }