diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java index 1e41edf85c..ffe1cf6677 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; @@ -103,6 +104,9 @@ import org.springframework.web.util.UriComponentsBuilder; */ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private final ClientRegistrationRepository clientRegistrationRepository; private final OAuth2AuthorizedClientRepository authorizedClientRepository; @@ -158,6 +162,17 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { this.requestCache = requestCache; } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { @@ -232,7 +247,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { this.redirectStrategy.sendRedirect(request, response, uriBuilder.build().encode().toString()); return; } - Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication(); + Authentication currentAuthentication = this.securityContextHolderStrategy.getContext().getAuthentication(); String principalName = (currentAuthentication != null) ? currentAuthentication.getName() : "anonymousUser"; OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( authenticationResult.getClientRegistration(), principalName, authenticationResult.getAccessToken(), diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 95ef7526fa..a7b9b9687d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -27,6 +27,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; @@ -67,6 +68,9 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private OAuth2AuthorizedClientManager authorizedClientManager; /** @@ -112,7 +116,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth + "It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or " + "@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); } - Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication(); if (principal == null) { principal = ANONYMOUS_AUTHENTICATION; } @@ -132,7 +136,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth private String resolveClientRegistrationId(MethodParameter parameter) { RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils .findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class); - Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication(); if (!StringUtils.isEmpty(authorizedClientAnnotation.registrationId())) { return authorizedClientAnnotation.registrationId(); } @@ -145,4 +149,15 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth return null; } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index c8cd1c58e1..d211168b97 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -36,6 +36,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; @@ -145,6 +146,9 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private OAuth2AuthorizedClientManager authorizedClientManager; private boolean defaultOAuth2AuthorizedClient; @@ -243,6 +247,17 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement this.defaultClientRegistrationId = clientRegistrationId; } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + /** * Configures the builder with {@link #defaultRequest()} and adds this as a * {@link ExchangeFilterFunction} @@ -431,7 +446,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) { return; } - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java index 777c6f6cf3..8b5a0abacd 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,8 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; @@ -306,6 +308,23 @@ public class OAuth2AuthorizationCodeGrantFilterTests { assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1"); } + @Test + public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws Exception { + MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); + this.setUpAuthenticationResult(this.registration1); + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()) + .willReturn(new SecurityContextImpl(new TestingAuthenticationToken("user", "password"))); + this.filter.setSecurityContextHolderStrategy(strategy); + this.filter.doFilter(authorizationResponse, response, filterChain); + verify(strategy).getContext(); + assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1"); + } + @Test public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception { String requestUri = "/saved-request"; diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 310bf1b122..16adae1a4b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -33,6 +33,8 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; @@ -67,6 +69,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -233,6 +236,18 @@ public class OAuth2AuthorizedClientArgumentResolverTests { new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1); } + @Test + public void resolveArgumentWhenCustomSecurityContextHolderStrategyThenUses() throws Exception { + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()).willReturn(new SecurityContextImpl(this.authentication)); + this.argumentResolver.setSecurityContextHolderStrategy(strategy); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", + OAuth2AuthorizedClient.class); + assertThat(this.argumentResolver.resolveArgument(methodParameter, null, + new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1); + verify(strategy, atLeastOnce()).getContext(); + } + @Test public void resolveArgumentWhenRegistrationIdInvalidThenThrowIllegalArgumentException() { MethodParameter methodParameter = this.getMethodParameter("registrationIdInvalid", diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 3087c24b85..2b6e039296 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -64,6 +64,8 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; @@ -250,6 +252,18 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { verifyNoInteractions(this.authorizedClientRepository); } + @Test + public void defaultRequestAuthenticationWhenCustomSecurityContextHolderStrategyThenAuthenticationSet() { + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()).willReturn(new SecurityContextImpl(this.authentication)); + this.function.setSecurityContextHolderStrategy(strategy); + Map attrs = getDefaultRequestAttributes(); + assertThat(ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication(attrs)) + .isEqualTo(this.authentication); + verify(strategy).getContext(); + verifyNoInteractions(this.authorizedClientRepository); + } + private Map getDefaultRequestAttributes() { this.function.defaultRequest().accept(this.spec); verify(this.spec).attributes(this.attrs.capture()); diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilter.java index c0bc2f5e44..6377899c87 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilter.java @@ -32,6 +32,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider; @@ -67,6 +68,9 @@ public class BearerTokenAuthenticationFilter extends OncePerRequestFilter { private final AuthenticationManagerResolver authenticationManagerResolver; + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private AuthenticationEntryPoint authenticationEntryPoint = new BearerTokenAuthenticationEntryPoint(); private AuthenticationFailureHandler authenticationFailureHandler = (request, response, exception) -> { @@ -135,9 +139,9 @@ public class BearerTokenAuthenticationFilter extends OncePerRequestFilter { try { AuthenticationManager authenticationManager = this.authenticationManagerResolver.resolve(request); Authentication authenticationResult = authenticationManager.authenticate(authenticationRequest); - SecurityContext context = SecurityContextHolder.createEmptyContext(); + SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); context.setAuthentication(authenticationResult); - SecurityContextHolder.setContext(context); + this.securityContextHolderStrategy.setContext(context); this.securityContextRepository.saveContext(context, request, response); if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authenticationResult)); @@ -145,12 +149,23 @@ public class BearerTokenAuthenticationFilter extends OncePerRequestFilter { filterChain.doFilter(request, response); } catch (AuthenticationException failed) { - SecurityContextHolder.clearContext(); + this.securityContextHolderStrategy.clearContext(); this.logger.trace("Failed to process authentication request", failed); this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed); } } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + /** * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on * authentication success. The default action is not to save the diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilterTests.java index 938cc31b03..5fce70fb34 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilterTests.java @@ -37,6 +37,8 @@ import org.springframework.security.authentication.AuthenticationManagerResolver import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.server.resource.BearerTokenError; import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes; @@ -210,6 +212,18 @@ public class BearerTokenAuthenticationFilterTests { verify(this.authenticationDetailsSource).buildDetails(this.request); } + @Test + public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws ServletException, IOException { + given(this.bearerTokenResolver.resolve(this.request)).willReturn("token"); + BearerTokenAuthenticationFilter filter = addMocks( + new BearerTokenAuthenticationFilter(this.authenticationManager)); + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.createEmptyContext()).willReturn(new SecurityContextImpl()); + filter.setSecurityContextHolderStrategy(strategy); + filter.doFilter(this.request, this.response, this.filterChain); + verify(strategy).setContext(any()); + } + @Test public void setAuthenticationEntryPointWhenNullThenThrowsException() { BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(this.authenticationManager);