From b3bd5ba94610b8fbbdf41017365adadbc7d49448 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Mon, 25 Sep 2017 22:06:32 -0500 Subject: [PATCH] Add Reactive HttpSecurity.addWebFilterAt Fixes gh-4542 --- .../config/web/server/HttpSecurity.java | 89 ++++++++++++++----- .../web/server/SecurityWebFiltersOrder.java | 58 ++++++++++++ 2 files changed, 124 insertions(+), 23 deletions(-) create mode 100644 config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java diff --git a/config/src/main/java/org/springframework/security/config/web/server/HttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/HttpSecurity.java index fe27e3992e..8b3cb4a973 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/HttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/HttpSecurity.java @@ -21,12 +21,15 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.http.MediaType; import org.springframework.security.web.server.DelegatingAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.AuthenticationFailureHandler; import org.springframework.security.web.server.authentication.logout.LogoutWebFiter; import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher; -import org.springframework.security.web.util.matcher.MediaTypeRequestMatcher; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilterChain; import reactor.core.publisher.Mono; import org.springframework.http.HttpMethod; @@ -94,6 +97,8 @@ public class HttpSecurity { private List defaultEntryPoints = new ArrayList<>(); + private List webFilters = new ArrayList<>(); + /** * The ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance. * @@ -106,6 +111,11 @@ public class HttpSecurity { return this; } + public HttpSecurity addFilterAt(WebFilter webFilter, SecurityWebFiltersOrder order) { + this.webFilters.add(new OrderedWebFilter(webFilter, order.getOrder())); + return this; + } + /** * Gets the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance. * @return the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance. @@ -154,20 +164,19 @@ public class HttpSecurity { } public SecurityWebFilterChain build() { - List filters = new ArrayList<>(); if(this.headers != null) { - filters.add(this.headers.build()); + this.webFilters.add(this.headers.build()); } - SecurityContextRepositoryWebFilter securityContextRepositoryWebFilter = securityContextRepositoryWebFilter(); + WebFilter securityContextRepositoryWebFilter = securityContextRepositoryWebFilter(); if(securityContextRepositoryWebFilter != null) { - filters.add(securityContextRepositoryWebFilter); + this.webFilters.add(securityContextRepositoryWebFilter); } if(this.httpBasic != null) { this.httpBasic.authenticationManager(this.authenticationManager); if(this.securityContextRepository != null) { this.httpBasic.securityContextRepository(this.securityContextRepository); } - filters.add(this.httpBasic.build()); + this.webFilters.add(this.httpBasic.build()); } if(this.formLogin != null) { this.formLogin.authenticationManager(this.authenticationManager); @@ -175,22 +184,24 @@ public class HttpSecurity { this.formLogin.securityContextRepository(this.securityContextRepository); } if(this.formLogin.authenticationEntryPoint == null) { - filters.add(new LoginPageGeneratingWebFilter()); + this.webFilters.add(new OrderedWebFilter(new LoginPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING.getOrder())); } - filters.add(this.formLogin.build()); - filters.add(new LogoutWebFiter()); + this.webFilters.add(this.formLogin.build()); + this.webFilters + .add(new OrderedWebFilter(new LogoutWebFiter(), SecurityWebFiltersOrder.LOGOUT.getOrder())); } - filters.add(new AuthenticationReactorContextFilter()); + this.webFilters.add(new OrderedWebFilter(new AuthenticationReactorContextFilter(), SecurityWebFiltersOrder.AUTHENTICATION_CONTEXT.getOrder())); if(this.authorizeExchangeBuilder != null) { AuthenticationEntryPoint authenticationEntryPoint = getAuthenticationEntryPoint(); ExceptionTranslationWebFilter exceptionTranslationWebFilter = new ExceptionTranslationWebFilter(); if(authenticationEntryPoint != null) { exceptionTranslationWebFilter.setAuthenticationEntryPoint(authenticationEntryPoint); } - filters.add(exceptionTranslationWebFilter); - filters.add(this.authorizeExchangeBuilder.build()); + this.webFilters.add(new OrderedWebFilter(exceptionTranslationWebFilter, SecurityWebFiltersOrder.EXCEPTION_TRANSLATION.getOrder())); + this.webFilters.add(this.authorizeExchangeBuilder.build()); } - return new MatcherSecurityWebFilterChain(getSecurityMatcher(), filters); + AnnotationAwareOrderComparator.sort(this.webFilters); + return new MatcherSecurityWebFilterChain(getSecurityMatcher(), this.webFilters); } private AuthenticationEntryPoint getAuthenticationEntryPoint() { @@ -209,10 +220,13 @@ public class HttpSecurity { return new HttpSecurity(); } - private SecurityContextRepositoryWebFilter securityContextRepositoryWebFilter() { + private WebFilter securityContextRepositoryWebFilter() { SecurityContextRepository repository = this.securityContextRepository; - return repository == null ? null : - new SecurityContextRepositoryWebFilter(repository); + if(repository == null) { + return null; + } + WebFilter result = new SecurityContextRepositoryWebFilter(repository); + return new OrderedWebFilter(result, SecurityWebFiltersOrder.SECURITY_CONTEXT_REPOSITORY.getOrder()); } private HttpSecurity() {} @@ -253,7 +267,8 @@ public class HttpSecurity { if(this.matcher != null) { throw new IllegalStateException("The matcher " + this.matcher + " does not have an access rule defined"); } - return new AuthorizationWebFilter(this.managerBldr.build()); + AuthorizationWebFilter result = new AuthorizationWebFilter(this.managerBldr.build()); + return new OrderedWebFilter(result, SecurityWebFiltersOrder.AUTHORIZATION.getOrder()); } public final class Access { @@ -318,7 +333,7 @@ public class HttpSecurity { return HttpSecurity.this; } - protected AuthenticationWebFilter build() { + protected WebFilter build() { MediaTypeServerWebExchangeMatcher restMatcher = new MediaTypeServerWebExchangeMatcher( MediaType.APPLICATION_ATOM_XML, MediaType.APPLICATION_FORM_URLENCODED, MediaType.APPLICATION_JSON, @@ -333,7 +348,7 @@ public class HttpSecurity { if(this.securityContextRepository != null) { authenticationFilter.setSecurityContextRepository(this.securityContextRepository); } - return authenticationFilter; + return new OrderedWebFilter(authenticationFilter, SecurityWebFiltersOrder.HTTP_BASIC.getOrder()); } private HttpBasicBuilder() {} @@ -395,7 +410,7 @@ public class HttpSecurity { return HttpSecurity.this; } - protected AuthenticationWebFilter build() { + protected WebFilter build() { if(this.authenticationEntryPoint == null) { loginPage("/login"); } @@ -410,7 +425,7 @@ public class HttpSecurity { authenticationFilter.setAuthenticationConverter(new FormLoginAuthenticationConverter()); authenticationFilter.setAuthenticationSuccessHandler(new RedirectAuthenticationSuccessHandler("/")); authenticationFilter.setSecurityContextRepository(this.securityContextRepository); - return authenticationFilter; + return new OrderedWebFilter(authenticationFilter, SecurityWebFiltersOrder.FORM_LOGIN.getOrder()); } private FormLoginBuilder() { @@ -454,9 +469,10 @@ public class HttpSecurity { return new HstsSpec(); } - protected HttpHeaderWriterWebFilter build() { + protected WebFilter build() { HttpHeadersWriter writer = new CompositeHttpHeadersWriter(this.writers); - return new HttpHeaderWriterWebFilter(writer); + HttpHeaderWriterWebFilter result = new HttpHeaderWriterWebFilter(writer); + return new OrderedWebFilter(result, SecurityWebFiltersOrder.HTTP_HEADERS_WRITER.getOrder()); } public XssProtectionSpec xssProtection() { @@ -520,4 +536,31 @@ public class HttpSecurity { this.frameOptions, this.xss)); } } + + private static class OrderedWebFilter implements WebFilter, Ordered { + private final WebFilter webFilter; + private final int order; + + public OrderedWebFilter(WebFilter webFilter, int order) { + this.webFilter = webFilter; + this.order = order; + } + + @Override + public Mono filter(ServerWebExchange exchange, + WebFilterChain chain) { + return this.webFilter.filter(exchange, chain); + } + + @Override + public int getOrder() { + return this.order; + } + + @Override + public String toString() { + return "OrderedWebFilter{" + "webFilter=" + this.webFilter + ", order=" + this.order + + '}'; + } + } } diff --git a/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java b/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java new file mode 100644 index 0000000000..19c202be7c --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.web.server; + + +/** + * @author Rob Winch + * @since 5.0 + */ +public enum SecurityWebFiltersOrder { + FIRST(Integer.MIN_VALUE), + HTTP_HEADERS_WRITER, + SECURITY_CONTEXT_REPOSITORY, + LOGIN_PAGE_GENERATING, + /** + * Instance of AuthenticationWebFilter + */ + HTTP_BASIC, + /** + * Instance of AuthenticationWebFilter + */ + FORM_LOGIN, + AUTHENTICATION, + LOGOUT, + AUTHENTICATION_CONTEXT, + EXCEPTION_TRANSLATION, + AUTHORIZATION, + LAST(Integer.MAX_VALUE); + + private static final int INTERVAL = 100; + + private final int order; + + private SecurityWebFiltersOrder() { + this.order = ordinal() * INTERVAL; + } + + private SecurityWebFiltersOrder(int order) { + this.order = order; + } + + public int getOrder() { + return this.order; + } +}