From 37592ea07c1257326c98f88ca07c37974ba9f118 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 16 Mar 2017 13:44:55 -0400 Subject: [PATCH] DefaultWebFilterChain is a top-level, public class Issue: SPR-15348 --- .../server/handler/DefaultWebFilterChain.java | 64 ++++++++++++++ .../server/handler/FilteringWebHandler.java | 36 +++----- .../handler/FilteringWebHandlerTests.java | 83 +++++++++---------- 3 files changed, 113 insertions(+), 70 deletions(-) create mode 100644 spring-web/src/main/java/org/springframework/web/server/handler/DefaultWebFilterChain.java diff --git a/spring-web/src/main/java/org/springframework/web/server/handler/DefaultWebFilterChain.java b/spring-web/src/main/java/org/springframework/web/server/handler/DefaultWebFilterChain.java new file mode 100644 index 00000000000..b2fe5801664 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/handler/DefaultWebFilterChain.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.server.handler; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.WebHandler; + +/** + * Default implementation of {@link WebFilterChain}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class DefaultWebFilterChain implements WebFilterChain { + + private final List filters; + + private final WebHandler handler; + + private volatile int index; + + + public DefaultWebFilterChain(WebHandler handler, WebFilter... filters) { + Assert.notNull(handler, "WebHandler is required"); + this.filters = ObjectUtils.isEmpty(filters) ? Collections.emptyList() : Arrays.asList(filters); + this.handler = handler; + } + + + @Override + public Mono filter(ServerWebExchange exchange) { + if (this.index < this.filters.size()) { + WebFilter filter = this.filters.get(this.index++); + return filter.filter(exchange, this); + } + else { + return this.handler.handle(exchange); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/handler/FilteringWebHandler.java b/spring-web/src/main/java/org/springframework/web/server/handler/FilteringWebHandler.java index 70710b8c0a2..94d2ff4fcfc 100644 --- a/spring-web/src/main/java/org/springframework/web/server/handler/FilteringWebHandler.java +++ b/spring-web/src/main/java/org/springframework/web/server/handler/FilteringWebHandler.java @@ -16,14 +16,15 @@ package org.springframework.web.server.handler; +import java.util.Arrays; import java.util.Collections; import java.util.List; import reactor.core.publisher.Mono; +import org.springframework.util.CollectionUtils; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; -import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebHandler; /** @@ -35,7 +36,7 @@ import org.springframework.web.server.WebHandler; */ public class FilteringWebHandler extends WebHandlerDecorator { - private final List filters; + private final WebFilter[] filters; /** @@ -44,41 +45,24 @@ public class FilteringWebHandler extends WebHandlerDecorator { */ public FilteringWebHandler(WebHandler webHandler, List filters) { super(webHandler); - this.filters = Collections.unmodifiableList(filters); + this.filters = !CollectionUtils.isEmpty(filters) ? + filters.toArray(new WebFilter[filters.size()]) : new WebFilter[0]; } /** - * Return read-only list of the configured filters. + * Return a read-only list of the configured filters. */ public List getFilters() { - return this.filters; + return Arrays.asList(this.filters); } @Override public Mono handle(ServerWebExchange exchange) { - if (this.filters.isEmpty()) { - return super.handle(exchange); - } - return new DefaultWebFilterChain().filter(exchange); - } - - - private class DefaultWebFilterChain implements WebFilterChain { - - private int index; - - @Override - public Mono filter(ServerWebExchange exchange) { - if (this.index < filters.size()) { - WebFilter filter = filters.get(this.index++); - return filter.filter(exchange, this); - } - else { - return getDelegate().handle(exchange); - } - } + return this.filters.length != 0 ? + new DefaultWebFilterChain(getDelegate(), this.filters).filter(exchange) : + super.handle(exchange); } } diff --git a/spring-web/src/test/java/org/springframework/web/server/handler/FilteringWebHandlerTests.java b/spring-web/src/test/java/org/springframework/web/server/handler/FilteringWebHandlerTests.java index ef8e63f14e3..a4aece7e4dc 100644 --- a/spring-web/src/test/java/org/springframework/web/server/handler/FilteringWebHandlerTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/handler/FilteringWebHandlerTests.java @@ -16,18 +16,16 @@ package org.springframework.web.server.handler; +import java.time.Duration; import java.util.Arrays; import java.util.Collections; -import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.junit.Before; import org.junit.Test; import reactor.core.publisher.Mono; import org.springframework.http.HttpStatus; -import org.springframework.http.server.reactive.HttpHandler; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.web.server.ServerWebExchange; @@ -49,89 +47,86 @@ public class FilteringWebHandlerTests { private static Log logger = LogFactory.getLog(FilteringWebHandlerTests.class); - private MockServerHttpRequest request; - - private MockServerHttpResponse response; - - - @Before - public void setUp() throws Exception { - this.request = MockServerHttpRequest.get("http://localhost").build(); - this.response = new MockServerHttpResponse(); - } - @Test public void multipleFilters() throws Exception { - StubWebHandler webHandler = new StubWebHandler(); + TestFilter filter1 = new TestFilter(); TestFilter filter2 = new TestFilter(); TestFilter filter3 = new TestFilter(); - HttpHandler httpHandler = createHttpHandler(webHandler, filter1, filter2, filter3); - httpHandler.handle(this.request, this.response).block(); + StubWebHandler targetHandler = new StubWebHandler(); + + new FilteringWebHandler(targetHandler, Arrays.asList(filter1, filter2, filter3)) + .handle(MockServerHttpRequest.get("/").toExchange()) + .block(Duration.ZERO); assertTrue(filter1.invoked()); assertTrue(filter2.invoked()); assertTrue(filter3.invoked()); - assertTrue(webHandler.invoked()); + assertTrue(targetHandler.invoked()); } @Test public void zeroFilters() throws Exception { - StubWebHandler webHandler = new StubWebHandler(); - HttpHandler httpHandler = createHttpHandler(webHandler); - httpHandler.handle(this.request, this.response).block(); - assertTrue(webHandler.invoked()); + StubWebHandler targetHandler = new StubWebHandler(); + + new FilteringWebHandler(targetHandler, Collections.emptyList()) + .handle(MockServerHttpRequest.get("/").toExchange()) + .block(Duration.ZERO); + + assertTrue(targetHandler.invoked()); } @Test public void shortcircuitFilter() throws Exception { - StubWebHandler webHandler = new StubWebHandler(); + TestFilter filter1 = new TestFilter(); ShortcircuitingFilter filter2 = new ShortcircuitingFilter(); TestFilter filter3 = new TestFilter(); - HttpHandler httpHandler = createHttpHandler(webHandler, filter1, filter2, filter3); - httpHandler.handle(this.request, this.response).block(); + StubWebHandler targetHandler = new StubWebHandler(); + + new FilteringWebHandler(targetHandler, Arrays.asList(filter1, filter2, filter3)) + .handle(MockServerHttpRequest.get("/").toExchange()) + .block(Duration.ZERO); assertTrue(filter1.invoked()); assertTrue(filter2.invoked()); assertFalse(filter3.invoked()); - assertFalse(webHandler.invoked()); + assertFalse(targetHandler.invoked()); } @Test public void asyncFilter() throws Exception { - StubWebHandler webHandler = new StubWebHandler(); + AsyncFilter filter = new AsyncFilter(); - HttpHandler httpHandler = createHttpHandler(webHandler, filter); - httpHandler.handle(this.request, this.response).block(); + StubWebHandler targetHandler = new StubWebHandler(); + + new FilteringWebHandler(targetHandler, Collections.singletonList(filter)) + .handle(MockServerHttpRequest.get("/").toExchange()) + .block(Duration.ZERO); assertTrue(filter.invoked()); - assertTrue(webHandler.invoked()); + assertTrue(targetHandler.invoked()); } @Test public void handleErrorFromFilter() throws Exception { + + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerHttpResponse response = new MockServerHttpResponse(); + TestExceptionHandler exceptionHandler = new TestExceptionHandler(); - List filters = Collections.singletonList(new ExceptionFilter()); - List exceptionHandlers = Collections.singletonList(exceptionHandler); WebHttpHandlerBuilder.webHandler(new StubWebHandler()) - .filters(filters).exceptionHandlers(exceptionHandlers).build() - .handle(this.request, this.response) + .filters(Collections.singletonList(new ExceptionFilter())) + .exceptionHandlers(Collections.singletonList(exceptionHandler)).build() + .handle(request, response) .block(); - assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, this.response.getStatusCode()); - - Throwable savedException = exceptionHandler.ex; - assertNotNull(savedException); - assertEquals("boo", savedException.getMessage()); - } - - - private HttpHandler createHttpHandler(StubWebHandler webHandler, WebFilter... filters) { - return WebHttpHandlerBuilder.webHandler(webHandler).filters(Arrays.asList(filters)).build(); + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode()); + assertNotNull(exceptionHandler.ex); + assertEquals("boo", exceptionHandler.ex.getMessage()); }