From f4cf55cb2b189d0071eff6834a3f7080ce0b000a Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Wed, 10 May 2017 14:35:32 +0200 Subject: [PATCH] Add support for WebFilter and WebExceptionHandler This commit adds support for configuring `WebFilter` and `WebExceptionHandler` instances in HandlerStrategies. It also drops the "native" support for `ResponseStatusException`s, in favor of the `ResponseStatusExceptionHandler`, which is registered by default. Issue: SPR-15518 --- .../DefaultHandlerStrategiesBuilder.java | 44 ++++++++++++++++++- .../function/server/DefaultServerRequest.java | 9 ++-- .../function/server/HandlerStrategies.java | 30 +++++++++++++ .../function/server/RouterFunctions.java | 31 +++++-------- .../server/DefaultServerRequestTests.java | 35 +++++---------- .../server/ResourceHandlerFunctionTests.java | 9 ++-- .../function/server/RouterFunctionsTests.java | 34 ++++++++++++++ 7 files changed, 138 insertions(+), 54 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultHandlerStrategiesBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultHandlerStrategiesBuilder.java index 2ff119d80aa..939ed1fc4e6 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultHandlerStrategiesBuilder.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultHandlerStrategiesBuilder.java @@ -32,6 +32,9 @@ import org.springframework.http.codec.HttpMessageWriter; import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.util.Assert; import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.handler.ResponseStatusExceptionHandler; /** * Default implementation of {@link HandlerStrategies.Builder}. @@ -53,6 +56,9 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder { private Function> localeResolver; + private final List webFilters = new ArrayList<>(); + + private final List exceptionHandlers = new ArrayList<>(); public DefaultHandlerStrategiesBuilder() { @@ -62,6 +68,7 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder { public void defaultConfiguration() { this.codecConfigurer.registerDefaults(true); localeResolver(DEFAULT_LOCALE_RESOLVER); + exceptionHandler(new ResponseStatusExceptionHandler()); } @Override @@ -94,10 +101,25 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder { return this; } + @Override + public HandlerStrategies.Builder webFilter(WebFilter filter) { + Assert.notNull(filter, "'filter' must not be null"); + this.webFilters.add(filter); + return this; + } + + @Override + public HandlerStrategies.Builder exceptionHandler(WebExceptionHandler exceptionHandler) { + Assert.notNull(exceptionHandler, "'exceptionHandler' must not be null"); + this.exceptionHandlers.add(exceptionHandler); + return this; + } + @Override public HandlerStrategies build() { return new DefaultHandlerStrategies(this.codecConfigurer.getReaders(), - this.codecConfigurer.getWriters(), this.viewResolvers, this.localeResolver); + this.codecConfigurer.getWriters(), this.viewResolvers, this.localeResolver, + this.webFilters, this.exceptionHandlers); } @@ -111,16 +133,24 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder { private final Function> localeResolver; + private final List webFilters; + + private final List exceptionHandlers; + public DefaultHandlerStrategies( List> messageReaders, List> messageWriters, List viewResolvers, - Function> localeResolver) { + Function> localeResolver, + List webFilters, + List exceptionHandlers) { this.messageReaders = unmodifiableCopy(messageReaders); this.messageWriters = unmodifiableCopy(messageWriters); this.viewResolvers = unmodifiableCopy(viewResolvers); this.localeResolver = localeResolver; + this.webFilters = unmodifiableCopy(webFilters); + this.exceptionHandlers = unmodifiableCopy(exceptionHandlers); } private static List unmodifiableCopy(List list) { @@ -146,6 +176,16 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder { public Supplier>> localeResolver() { return () -> this.localeResolver; } + + @Override + public Supplier> webFilters() { + return this.webFilters::stream; + } + + @Override + public Supplier> exceptionHandlers() { + return this.exceptionHandlers::stream; + } } } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java index b40e3c34dba..ecc87d06187 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java @@ -65,12 +65,13 @@ class DefaultServerRequest implements ServerRequest { private final Headers headers; - private final HandlerStrategies strategies; + private final Supplier>> messageReaders; - DefaultServerRequest(ServerWebExchange exchange, HandlerStrategies strategies) { + DefaultServerRequest(ServerWebExchange exchange, + Supplier>> messageReaders) { this.exchange = exchange; - this.strategies = strategies; + this.messageReaders = messageReaders; this.headers = new DefaultHeaders(); } @@ -102,7 +103,7 @@ class DefaultServerRequest implements ServerRequest { new BodyExtractor.Context() { @Override public Supplier>> messageReaders() { - return DefaultServerRequest.this.strategies.messageReaders(); + return DefaultServerRequest.this.messageReaders; } @Override diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerStrategies.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerStrategies.java index 14e37996ad8..4fc70435c22 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerStrategies.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerStrategies.java @@ -27,6 +27,8 @@ import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageWriter; import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.WebFilter; /** * Defines the strategies to be used for processing {@link HandlerFunction}s. An instance of @@ -71,6 +73,20 @@ public interface HandlerStrategies { */ Supplier>> localeResolver(); + /** + * Supply a {@linkplain Stream stream} of {@link WebFilter}s to be used for filtering the + * request and response. + * @return the stream of web filters + */ + Supplier> webFilters(); + + /** + * Supply a {@linkplain Stream stream} of {@link WebExceptionHandler}s to be used for handling + * exceptions. + * @return the stream of exception handlers + */ + Supplier> exceptionHandlers(); + // Static methods @@ -138,6 +154,20 @@ public interface HandlerStrategies { */ Builder localeResolver(Function> localeResolver); + /** + * Add the given web filter to this builder. + * @param filter the filter to add + * @return this builder + */ + Builder webFilter(WebFilter filter); + + /** + * Add the given exception handler to this builder. + * @param exceptionHandler the exception handler to add + * @return this builder + */ + Builder exceptionHandler(WebExceptionHandler exceptionHandler); + /** * Builds the {@link HandlerStrategies}. * @return the built strategies diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java index a09e8061f42..8932c31e28a 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java @@ -30,10 +30,8 @@ import org.springframework.util.Assert; import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.function.server.support.HandlerFunctionAdapter; import org.springframework.web.reactive.function.server.support.ServerResponseResultHandler; -import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebHandler; -import org.springframework.web.server.adapter.HttpWebHandlerAdapter; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; /** @@ -197,7 +195,7 @@ public abstract class RouterFunctions { * @param routerFunction the router function to convert * @return an http handler that handles HTTP request using the given router function */ - public static HttpWebHandlerAdapter toHttpHandler(RouterFunction routerFunction) { + public static HttpHandler toHttpHandler(RouterFunction routerFunction) { return toHttpHandler(routerFunction, HandlerStrategies.withDefaults()); } @@ -213,32 +211,27 @@ public abstract class RouterFunctions { *
  • Undertow using the * {@link org.springframework.http.server.reactive.UndertowHttpHandlerAdapter}.
  • * - *

    Note that {@code HttpWebHandlerAdapter} also implements {@link WebHandler}, allowing - * for additional filter and exception handler registration through * @param routerFunction the router function to convert * @param strategies the strategies to use * @return an http handler that handles HTTP request using the given router function */ - public static HttpWebHandlerAdapter toHttpHandler(RouterFunction routerFunction, HandlerStrategies strategies) { + public static HttpHandler toHttpHandler(RouterFunction routerFunction, HandlerStrategies strategies) { Assert.notNull(routerFunction, "RouterFunction must not be null"); Assert.notNull(strategies, "HandlerStrategies must not be null"); - return new HttpWebHandlerAdapter(exchange -> { - ServerRequest request = new DefaultServerRequest(exchange, strategies); + WebHandler webHandler = exchange -> { + ServerRequest request = new DefaultServerRequest(exchange, strategies.messageReaders()); addAttributes(exchange, request); return routerFunction.route(request) .defaultIfEmpty(notFound()) .flatMap(handlerFunction -> wrapException(() -> handlerFunction.handle(request))) - .flatMap(response -> wrapException(() -> response.writeTo(exchange, strategies))) - .onErrorResume(ResponseStatusException.class, - ex -> { - exchange.getResponse().setStatusCode(ex.getStatus()); - if (ex.getMessage() != null) { - logger.error(ex.getMessage()); - } - return Mono.empty(); - }); - }); + .flatMap(response -> wrapException(() -> response.writeTo(exchange, strategies))); + }; + + WebHttpHandlerBuilder handlerBuilder = WebHttpHandlerBuilder.webHandler(webHandler); + strategies.webFilters().get().forEach(handlerBuilder::filter); + strategies.exceptionHandlers().get().forEach(handlerBuilder::exceptionHandler); + return handlerBuilder.build(); } private static Mono wrapException(Supplier> supplier) { @@ -280,7 +273,7 @@ public abstract class RouterFunctions { Assert.notNull(strategies, "HandlerStrategies must not be null"); return exchange -> { - ServerRequest request = new DefaultServerRequest(exchange, strategies); + ServerRequest request = new DefaultServerRequest(exchange, strategies.messageReaders()); addAttributes(exchange, request); return routerFunction.route(request).map(handlerFunction -> (Object)handlerFunction); }; diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestTests.java index 65a2cd58eb2..f00d89fc4e2 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestTests.java @@ -26,7 +26,8 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalLong; -import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Stream; import org.junit.Before; import org.junit.Test; @@ -52,9 +53,8 @@ import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.UnsupportedMediaTypeStatusException; import org.springframework.web.server.WebSession; -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; import static org.springframework.web.reactive.function.BodyExtractors.toMono; /** @@ -64,11 +64,9 @@ public class DefaultServerRequestTests { private ServerHttpRequest mockRequest; - private ServerHttpResponse mockResponse; - private ServerWebExchange mockExchange; - private HandlerStrategies mockHandlerStrategies; + Supplier>> messageReaders; private DefaultServerRequest defaultRequest; @@ -76,14 +74,15 @@ public class DefaultServerRequestTests { @Before public void createMocks() { mockRequest = mock(ServerHttpRequest.class); - mockResponse = mock(ServerHttpResponse.class); + ServerHttpResponse mockResponse = mock(ServerHttpResponse.class); mockExchange = mock(ServerWebExchange.class); when(mockExchange.getRequest()).thenReturn(mockRequest); when(mockExchange.getResponse()).thenReturn(mockResponse); - mockHandlerStrategies = mock(HandlerStrategies.class); - defaultRequest = new DefaultServerRequest(mockExchange, mockHandlerStrategies); + this.messageReaders = Collections.>singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true)))::stream; + + defaultRequest = new DefaultServerRequest(mockExchange, messageReaders); } @@ -190,10 +189,6 @@ public class DefaultServerRequestTests { when(mockRequest.getHeaders()).thenReturn(httpHeaders); when(mockRequest.getBody()).thenReturn(body); - Set> messageReaders = Collections - .singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true))); - when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream); - Mono resultMono = defaultRequest.body(toMono(String.class)); assertEquals("foo", resultMono.block()); } @@ -210,10 +205,6 @@ public class DefaultServerRequestTests { when(mockRequest.getHeaders()).thenReturn(httpHeaders); when(mockRequest.getBody()).thenReturn(body); - Set> messageReaders = Collections - .singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true))); - when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream); - Mono resultMono = defaultRequest.bodyToMono(String.class); assertEquals("foo", resultMono.block()); } @@ -230,10 +221,6 @@ public class DefaultServerRequestTests { when(mockRequest.getHeaders()).thenReturn(httpHeaders); when(mockRequest.getBody()).thenReturn(body); - Set> messageReaders = Collections - .singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true))); - when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream); - Flux resultFlux = defaultRequest.bodyToFlux(String.class); Mono> result = resultFlux.collectList(); assertEquals(Collections.singletonList("foo"), result.block()); @@ -251,8 +238,8 @@ public class DefaultServerRequestTests { when(mockRequest.getHeaders()).thenReturn(httpHeaders); when(mockRequest.getBody()).thenReturn(body); - Set> messageReaders = Collections.emptySet(); - when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream); + this.messageReaders = Collections.>emptySet()::stream; + this.defaultRequest = new DefaultServerRequest(mockExchange, messageReaders); Flux resultFlux = defaultRequest.bodyToFlux(String.class); StepVerifier.create(resultFlux) diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/ResourceHandlerFunctionTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/ResourceHandlerFunctionTests.java index 59c51c3fd0d..0942f65dd85 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/ResourceHandlerFunctionTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/ResourceHandlerFunctionTests.java @@ -33,8 +33,7 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.mock.http.server.reactive.test.MockServerWebExchange; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.Assert.*; /** * @author Arjen Poutsma @@ -51,7 +50,7 @@ public class ResourceHandlerFunctionTests { MockServerWebExchange exchange = MockServerHttpRequest.get("http://localhost").toExchange(); MockServerHttpResponse mockResponse = exchange.getResponse(); - ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults()); + ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults().messageReaders()); Mono responseMono = this.handlerFunction.handle(request); @@ -86,7 +85,7 @@ public class ResourceHandlerFunctionTests { MockServerWebExchange exchange = MockServerHttpRequest.head("http://localhost").toExchange(); MockServerHttpResponse mockResponse = exchange.getResponse(); - ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults()); + ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults().messageReaders()); Mono responseMono = this.handlerFunction.handle(request); @@ -110,7 +109,7 @@ public class ResourceHandlerFunctionTests { MockServerWebExchange exchange = MockServerHttpRequest.options("http://localhost").toExchange(); MockServerHttpResponse mockResponse = exchange.getResponse(); - ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults()); + ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults().messageReaders()); Mono responseMono = this.handlerFunction.handle(request); Mono result = responseMono.flatMap(response -> { diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionsTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionsTests.java index 26a7550c884..99ec483bf8e 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionsTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionsTests.java @@ -17,6 +17,7 @@ package org.springframework.web.reactive.function.server; import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import reactor.core.publisher.Mono; @@ -29,6 +30,8 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -245,4 +248,35 @@ public class RouterFunctionsTests { assertEquals(HttpStatus.NOT_FOUND, httpResponse.getStatusCode()); } + @Test + public void toHttpHandlerWebFilter() throws Exception { + AtomicBoolean filterInvoked = new AtomicBoolean(); + + WebFilter webFilter = new WebFilter() { + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + filterInvoked.set(true); + return chain.filter(exchange); + } + }; + + HandlerFunction handlerFunction = request -> ServerResponse.accepted().build(); + RouterFunction routerFunction = + RouterFunctions.route(RequestPredicates.all(), handlerFunction); + + HandlerStrategies handlerStrategies = HandlerStrategies.builder() + .webFilter(webFilter).build(); + + HttpHandler result = RouterFunctions.toHttpHandler(routerFunction, handlerStrategies); + assertNotNull(result); + + MockServerHttpRequest httpRequest = MockServerHttpRequest.get("http://localhost").build(); + MockServerHttpResponse httpResponse = new MockServerHttpResponse(); + result.handle(httpRequest, httpResponse).block(); + assertEquals(HttpStatus.ACCEPTED, httpResponse.getStatusCode()); + + assertTrue(filterInvoked.get()); + } + + }