From 8b3fb55aeae3b8101444a88cf50c255fd33d4fea Mon Sep 17 00:00:00 2001 From: Ankur Pathak Date: Fri, 30 Nov 2018 12:34:10 +0530 Subject: [PATCH] Added methods to add filter relatively in ServerHttpSecurity Addition of two new methods addFilterBefore and addFilterAfter in ServerHttpSecurity to allow addition of WebFilter before and after of specified order Fixes: gh-6138 --- .../config/web/server/ServerHttpSecurity.java | 26 ++++++++++++++ .../web/server/ServerHttpSecurityTests.java | 34 +++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 7021e05de3..ca75c30d3c 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -288,6 +288,32 @@ public class ServerHttpSecurity { return this; } + /** + * Adds a {@link WebFilter} before specific position. + * @param webFilter the {@link WebFilter} to add + * @param order the place before which to insert the {@link WebFilter} + * @return the {@link ServerHttpSecurity} to continue configuring + * @since 5.2.0 + * @author Ankur Pathak + */ + public ServerHttpSecurity addFilterBefore(WebFilter webFilter, SecurityWebFiltersOrder order) { + this.webFilters.add(new OrderedWebFilter(webFilter, order.getOrder() - 1)); + return this; + } + + /** + * Adds a {@link WebFilter} after specific position. + * @param webFilter the {@link WebFilter} to add + * @param order the place after which to insert the {@link WebFilter} + * @return the {@link ServerHttpSecurity} to continue configuring + * @since 5.2.0 + * @author Ankur Pathak + */ + public ServerHttpSecurity addFilterAfter(WebFilter webFilter, SecurityWebFiltersOrder order) { + this.webFilters.add(new OrderedWebFilter(webFilter, order.getOrder() + 1)); + return this; + } + /** * Gets the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance. * @return the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance. diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index bac111a24a..eb10e23464 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -34,6 +34,8 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; +import org.springframework.web.server.WebFilterChain; import reactor.core.publisher.Mono; import reactor.test.publisher.TestPublisher; @@ -190,6 +192,30 @@ public class ServerHttpSecurityTests { .isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class)); } + @Test + @SuppressWarnings("unchecked") + public void addFilterAfterIsApplied(){ + SecurityWebFilterChain securityWebFilterChain = this.http.addFilterAfter(new TestWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE).build(); + List filters = securityWebFilterChain.getWebFilters().map(WebFilter::getClass).collectList().block(); + + assertThat(filters).isNotNull() + .isNotEmpty() + .containsSequence(SecurityContextServerWebExchangeWebFilter.class, TestWebFilter.class); + + } + + @Test + @SuppressWarnings("unchecked") + public void addFilterBeforeIsApplied(){ + SecurityWebFilterChain securityWebFilterChain = this.http.addFilterBefore(new TestWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE).build(); + List filters = securityWebFilterChain.getWebFilters().map(WebFilter::getClass).collectList().block(); + + assertThat(filters).isNotNull() + .isNotEmpty() + .containsSequence(TestWebFilter.class, SecurityContextServerWebExchangeWebFilter.class); + + } + private Optional getWebFilter(SecurityWebFilterChain filterChain, Class filterClass) { return (Optional) filterChain.getWebFilters() .filter(Objects::nonNull) @@ -214,4 +240,12 @@ public class ServerHttpSecurityTests { .map(e -> e.getRequest().getPath().pathWithinApplication().value()); } } + + private static class TestWebFilter implements WebFilter { + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + return chain.filter(exchange); + } + } }