From b060ec050a5b961cb68945f29995faef55be8f58 Mon Sep 17 00:00:00 2001 From: Eric Deandrea Date: Thu, 24 May 2018 14:16:04 -0400 Subject: [PATCH] Automatically add CsrfServerLogoutHandler if csrf enabled The configuration DSL should automatically add CsrfServerLogoutHandler if csrf is enabled Fixes gh-5337 --- .../config/web/server/ServerHttpSecurity.java | 30 ++++++- .../web/server/ServerHttpSecurityTests.java | 78 +++++++++++++++-- .../logout/DelegatingServerLogoutHandler.java | 22 +++-- .../logout/LogoutWebFilter.java | 10 ++- .../logout/LogoutWebFilterTests.java | 86 +++++++++++++++++++ 5 files changed, 205 insertions(+), 21 deletions(-) create mode 100644 web/src/test/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilterTests.java diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index d1a1414996..401ad2d14b 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -27,6 +27,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.Function; import reactor.core.publisher.Mono; @@ -92,7 +93,9 @@ import org.springframework.security.web.server.authentication.ServerAuthenticati import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; import org.springframework.security.web.server.authentication.ServerFormLoginAuthenticationConverter; import org.springframework.security.web.server.authentication.ServerHttpBasicAuthenticationConverter; +import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.LogoutWebFilter; +import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler; import org.springframework.security.web.server.authorization.AuthorizationContext; @@ -106,8 +109,10 @@ import org.springframework.security.web.server.context.ReactorContextWebFilter; import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; +import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler; import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; +import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository; import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter; import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter; import org.springframework.security.web.server.header.ContentSecurityPolicyServerHttpHeadersWriter; @@ -1538,6 +1543,7 @@ public class ServerHttpSecurity { */ public class CsrfSpec { private CsrfWebFilter filter = new CsrfWebFilter(); + private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository(); private boolean specifiedRequireCsrfProtectionMatcher; @@ -1563,7 +1569,7 @@ public class ServerHttpSecurity { */ public CsrfSpec csrfTokenRepository( ServerCsrfTokenRepository csrfTokenRepository) { - this.filter.setCsrfTokenRepository(csrfTokenRepository); + this.csrfTokenRepository = csrfTokenRepository; return this; } @@ -1600,6 +1606,10 @@ public class ServerHttpSecurity { } protected void configure(ServerHttpSecurity http) { + Optional.ofNullable(this.csrfTokenRepository).ifPresent(serverCsrfTokenRepository -> { + this.filter.setCsrfTokenRepository(serverCsrfTokenRepository); + http.logout().logoutHandler(new CsrfServerLogoutHandler(serverCsrfTokenRepository)); + }); http.addFilterAt(this.filter, SecurityWebFiltersOrder.CSRF); } @@ -2332,6 +2342,7 @@ public class ServerHttpSecurity { */ public final class LogoutSpec { private LogoutWebFilter logoutWebFilter = new LogoutWebFilter(); + private List logoutHandlers = new ArrayList<>(Arrays.asList(new SecurityContextServerLogoutHandler())); /** * Configures the logout handler. Default is {@code SecurityContextServerLogoutHandler} @@ -2339,7 +2350,10 @@ public class ServerHttpSecurity { * @return the {@link LogoutSpec} to configure */ public LogoutSpec logoutHandler(ServerLogoutHandler logoutHandler) { - this.logoutWebFilter.setLogoutHandler(logoutHandler); + if (logoutHandler != null) { + this.logoutHandlers.add(logoutHandler); + } + return this; } @@ -2387,7 +2401,19 @@ public class ServerHttpSecurity { return and(); } + private Optional createLogoutHandler() { + if (this.logoutHandlers.isEmpty()) { + return Optional.empty(); + } + else if (this.logoutHandlers.size() == 1) { + return Optional.of(this.logoutHandlers.get(0)); + } + + return Optional.of(new DelegatingServerLogoutHandler(this.logoutHandlers)); + } + protected void configure(ServerHttpSecurity http) { + createLogoutHandler().ifPresent(this.logoutWebFilter::setLogoutHandler); http.addFilterAt(this.logoutWebFilter, SecurityWebFiltersOrder.LOGOUT); } diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index e13633efc8..bac111a24a 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -16,12 +16,27 @@ package org.springframework.security.config.web.server; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + import org.apache.http.HttpHeaders; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + +import reactor.core.publisher.Mono; +import reactor.test.publisher.TestPublisher; + import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; @@ -29,21 +44,23 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; +import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler; +import org.springframework.security.web.server.authentication.logout.LogoutWebFilter; +import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler; +import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler; import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; +import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler; +import org.springframework.security.web.server.csrf.CsrfWebFilter; +import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; +import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.reactive.server.EntityExchangeResult; import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; -import reactor.test.publisher.TestPublisher; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.BDDMockito.given; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.when; +import org.springframework.web.server.WebFilter; /** * @author Rob Winch @@ -55,6 +72,8 @@ public class ServerHttpSecurityTests { private ServerSecurityContextRepository contextRepository; @Mock private ReactiveAuthenticationManager authenticationManager; + @Mock + private ServerCsrfTokenRepository csrfTokenRepository; private ServerHttpSecurity http; @@ -134,6 +153,51 @@ public class ServerHttpSecurityTests { .expectBody(String.class).isEqualTo("/foo/bar"); } + @Test + public void csrfServerLogoutHandlerNotAppliedIfCsrfIsntEnabled() { + SecurityWebFilterChain securityWebFilterChain = this.http.csrf().disable().build(); + + assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)) + .isNotPresent(); + + Optional logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) + .map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); + + assertThat(logoutHandler) + .get() + .isExactlyInstanceOf(SecurityContextServerLogoutHandler.class); + } + + @Test + public void csrfServerLogoutHandlerAppliedIfCsrfIsEnabled() { + SecurityWebFilterChain securityWebFilterChain = this.http.csrf().csrfTokenRepository(this.csrfTokenRepository).and().build(); + + assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)) + .get() + .extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository")) + .isEqualTo(this.csrfTokenRepository); + + Optional logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) + .map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); + + assertThat(logoutHandler) + .get() + .isExactlyInstanceOf(DelegatingServerLogoutHandler.class) + .extracting(delegatingLogoutHandler -> + ((List) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream() + .map(ServerLogoutHandler::getClass) + .collect(Collectors.toList())) + .isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class)); + } + + private Optional getWebFilter(SecurityWebFilterChain filterChain, Class filterClass) { + return (Optional) filterChain.getWebFilters() + .filter(Objects::nonNull) + .filter(filter -> filter.getClass().isAssignableFrom(filterClass)) + .singleOrEmpty() + .blockOptional(); + } + private WebTestClient buildClient() { WebFilterChainProxy springSecurityFilterChain = new WebFilterChainProxy( this.http.build()); diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/logout/DelegatingServerLogoutHandler.java b/web/src/main/java/org/springframework/security/web/server/authentication/logout/DelegatingServerLogoutHandler.java index f809f43f17..de4cd2cc0c 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/logout/DelegatingServerLogoutHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/logout/DelegatingServerLogoutHandler.java @@ -18,16 +18,17 @@ package org.springframework.security.web.server.authentication.logout; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; +import java.util.Objects; import java.util.stream.Collectors; -import java.util.stream.Stream; + +import reactor.core.publisher.Mono; import org.springframework.security.core.Authentication; import org.springframework.security.web.server.WebFilterExchange; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; - /** * Delegates to a collection of {@link ServerLogoutHandler} implementations. * @@ -35,21 +36,24 @@ import reactor.core.publisher.Mono; * @since 5.1 */ public class DelegatingServerLogoutHandler implements ServerLogoutHandler { - private final List delegates; + private final List delegates = new ArrayList<>(); public DelegatingServerLogoutHandler(ServerLogoutHandler... delegates) { Assert.notEmpty(delegates, "delegates cannot be null or empty"); - this.delegates = Arrays.asList(delegates); + this.delegates.addAll(Arrays.asList(delegates)); } - public DelegatingServerLogoutHandler(List delegates) { + public DelegatingServerLogoutHandler(Collection delegates) { Assert.notEmpty(delegates, "delegates cannot be null or empty"); - this.delegates = new ArrayList<>(delegates); + this.delegates.addAll(delegates); } @Override public Mono logout(WebFilterExchange exchange, Authentication authentication) { - Stream> results = this.delegates.stream().map(delegate -> delegate.logout(exchange, authentication)); - return Mono.when(results.collect(Collectors.toList())); + return Mono.when(this.delegates.stream() + .filter(Objects::nonNull) + .map(delegate -> delegate.logout(exchange, authentication)) + .collect(Collectors.toList()) + ); } } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilter.java index f43e0a7a9b..91f28e75df 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilter.java @@ -16,17 +16,17 @@ package org.springframework.security.web.server.authentication.logout; -import org.springframework.http.HttpMethod; -import org.springframework.security.core.context.ReactiveSecurityContextHolder; -import org.springframework.util.Assert; import reactor.core.publisher.Mono; +import org.springframework.http.HttpMethod; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.web.server.WebFilterExchange; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; +import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; @@ -85,6 +85,10 @@ public class LogoutWebFilter implements WebFilter { this.logoutSuccessHandler = logoutSuccessHandler; } + /** + * Sets the {@link ServerLogoutHandler}. The default is {@link SecurityContextServerLogoutHandler}. + * @param logoutHandler The handler to use + */ public void setLogoutHandler(ServerLogoutHandler logoutHandler) { Assert.notNull(logoutHandler, "logoutHandler must not be null"); this.logoutHandler = logoutHandler; diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilterTests.java new file mode 100644 index 0000000000..ea39434620 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilterTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.authentication.logout; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Arrays; +import java.util.Collection; +import java.util.stream.Collectors; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import org.springframework.test.util.ReflectionTestUtils; + +/** + * @author Eric Deandrea + * @since 5.1 + */ +@RunWith(MockitoJUnitRunner.class) +public class LogoutWebFilterTests { + @Mock + private ServerLogoutHandler handler1; + + @Mock + private ServerLogoutHandler handler2; + + @Mock + private ServerLogoutHandler handler3; + + private LogoutWebFilter logoutWebFilter = new LogoutWebFilter(); + + @Test + public void defaultLogoutHandler() { + assertThat(getLogoutHandler()) + .isNotNull() + .isExactlyInstanceOf(SecurityContextServerLogoutHandler.class); + } + + @Test + public void singleLogoutHandler() { + this.logoutWebFilter.setLogoutHandler(this.handler1); + this.logoutWebFilter.setLogoutHandler(this.handler2); + + assertThat(getLogoutHandler()) + .isNotNull() + .isInstanceOf(ServerLogoutHandler.class) + .isNotInstanceOf(SecurityContextServerLogoutHandler.class) + .extracting(ServerLogoutHandler::getClass) + .isEqualTo(this.handler2.getClass()); + } + + @Test + public void multipleLogoutHandlers() { + this.logoutWebFilter.setLogoutHandler(new DelegatingServerLogoutHandler(this.handler1, this.handler2, this.handler3)); + + assertThat(getLogoutHandler()) + .isNotNull() + .isExactlyInstanceOf(DelegatingServerLogoutHandler.class) + .extracting(delegatingLogoutHandler -> ((Collection) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")) + .stream() + .map(ServerLogoutHandler::getClass) + .collect(Collectors.toList())) + .isEqualTo(Arrays.asList(this.handler1.getClass(), this.handler2.getClass(), this.handler3.getClass())); + } + + private ServerLogoutHandler getLogoutHandler() { + return (ServerLogoutHandler) ReflectionTestUtils.getField(this.logoutWebFilter, LogoutWebFilter.class, "logoutHandler"); + } +}