Browse Source

Add SecurityContextHolderStrategy to OAuth2

Issue gh-11060
pull/11964/head
Josh Cummings 4 years ago
parent
commit
14584b0562
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
  1. 19
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java
  2. 19
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java
  3. 17
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java
  4. 21
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java
  5. 15
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java
  6. 14
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java
  7. 21
      oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilter.java
  8. 14
      oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilterTests.java

19
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,6 +32,7 @@ import org.springframework.security.authentication.AuthenticationDetailsSource; @@ -32,6 +32,7 @@ import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
@ -103,6 +104,9 @@ import org.springframework.web.util.UriComponentsBuilder; @@ -103,6 +104,9 @@ import org.springframework.web.util.UriComponentsBuilder;
*/
public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private final ClientRegistrationRepository clientRegistrationRepository;
private final OAuth2AuthorizedClientRepository authorizedClientRepository;
@ -158,6 +162,17 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { @@ -158,6 +162,17 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
this.requestCache = requestCache;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
@ -232,7 +247,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { @@ -232,7 +247,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
this.redirectStrategy.sendRedirect(request, response, uriBuilder.build().encode().toString());
return;
}
Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
Authentication currentAuthentication = this.securityContextHolderStrategy.getContext().getAuthentication();
String principalName = (currentAuthentication != null) ? currentAuthentication.getName() : "anonymousUser";
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
authenticationResult.getClientRegistration(), principalName, authenticationResult.getAccessToken(),

19
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java

@ -27,6 +27,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -27,6 +27,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
@ -67,6 +68,9 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth @@ -67,6 +68,9 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous",
"anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private OAuth2AuthorizedClientManager authorizedClientManager;
/**
@ -112,7 +116,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth @@ -112,7 +116,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
+ "It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or "
+ "@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
}
Authentication principal = SecurityContextHolder.getContext().getAuthentication();
Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
if (principal == null) {
principal = ANONYMOUS_AUTHENTICATION;
}
@ -132,7 +136,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth @@ -132,7 +136,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
private String resolveClientRegistrationId(MethodParameter parameter) {
RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils
.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
Authentication principal = SecurityContextHolder.getContext().getAuthentication();
Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
if (!StringUtils.isEmpty(authorizedClientAnnotation.registrationId())) {
return authorizedClientAnnotation.registrationId();
}
@ -145,4 +149,15 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth @@ -145,4 +149,15 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
return null;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
}

17
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@ -36,6 +36,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -36,6 +36,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
@ -145,6 +146,9 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @@ -145,6 +146,9 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous",
"anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private OAuth2AuthorizedClientManager authorizedClientManager;
private boolean defaultOAuth2AuthorizedClient;
@ -243,6 +247,17 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @@ -243,6 +247,17 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
this.defaultClientRegistrationId = clientRegistrationId;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/**
* Configures the builder with {@link #defaultRequest()} and adds this as a
* {@link ExchangeFilterFunction}
@ -431,7 +446,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @@ -431,7 +446,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) {
return;
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
}

21
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,6 +38,8 @@ import org.springframework.security.core.Authentication; @@ -38,6 +38,8 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
@ -306,6 +308,23 @@ public class OAuth2AuthorizationCodeGrantFilterTests { @@ -306,6 +308,23 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
}
@Test
public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
given(strategy.getContext())
.willReturn(new SecurityContextImpl(new TestingAuthenticationToken("user", "password")));
this.filter.setSecurityContextHolderStrategy(strategy);
this.filter.doFilter(authorizationResponse, response, filterChain);
verify(strategy).getContext();
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
}
@Test
public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception {
String requestUri = "/saved-request";

15
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java

@ -33,6 +33,8 @@ import org.springframework.security.authentication.TestingAuthenticationToken; @@ -33,6 +33,8 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
@ -67,6 +69,7 @@ import static org.mockito.ArgumentMatchers.any; @@ -67,6 +69,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -233,6 +236,18 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @@ -233,6 +236,18 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1);
}
@Test
public void resolveArgumentWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
given(strategy.getContext()).willReturn(new SecurityContextImpl(this.authentication));
this.argumentResolver.setSecurityContextHolderStrategy(strategy);
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient",
OAuth2AuthorizedClient.class);
assertThat(this.argumentResolver.resolveArgument(methodParameter, null,
new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1);
verify(strategy, atLeastOnce()).getContext();
}
@Test
public void resolveArgumentWhenRegistrationIdInvalidThenThrowIllegalArgumentException() {
MethodParameter methodParameter = this.getMethodParameter("registrationIdInvalid",

14
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@ -64,6 +64,8 @@ import org.springframework.security.core.Authentication; @@ -64,6 +64,8 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
@ -250,6 +252,18 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @@ -250,6 +252,18 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
verifyNoInteractions(this.authorizedClientRepository);
}
@Test
public void defaultRequestAuthenticationWhenCustomSecurityContextHolderStrategyThenAuthenticationSet() {
SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
given(strategy.getContext()).willReturn(new SecurityContextImpl(this.authentication));
this.function.setSecurityContextHolderStrategy(strategy);
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication(attrs))
.isEqualTo(this.authentication);
verify(strategy).getContext();
verifyNoInteractions(this.authorizedClientRepository);
}
private Map<String, Object> getDefaultRequestAttributes() {
this.function.defaultRequest().accept(this.spec);
verify(this.spec).attributes(this.attrs.capture());

21
oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilter.java

@ -32,6 +32,7 @@ import org.springframework.security.core.Authentication; @@ -32,6 +32,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthenticationToken;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider;
@ -67,6 +68,9 @@ public class BearerTokenAuthenticationFilter extends OncePerRequestFilter { @@ -67,6 +68,9 @@ public class BearerTokenAuthenticationFilter extends OncePerRequestFilter {
private final AuthenticationManagerResolver<HttpServletRequest> authenticationManagerResolver;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private AuthenticationEntryPoint authenticationEntryPoint = new BearerTokenAuthenticationEntryPoint();
private AuthenticationFailureHandler authenticationFailureHandler = (request, response, exception) -> {
@ -135,9 +139,9 @@ public class BearerTokenAuthenticationFilter extends OncePerRequestFilter { @@ -135,9 +139,9 @@ public class BearerTokenAuthenticationFilter extends OncePerRequestFilter {
try {
AuthenticationManager authenticationManager = this.authenticationManagerResolver.resolve(request);
Authentication authenticationResult = authenticationManager.authenticate(authenticationRequest);
SecurityContext context = SecurityContextHolder.createEmptyContext();
SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authenticationResult);
SecurityContextHolder.setContext(context);
this.securityContextHolderStrategy.setContext(context);
this.securityContextRepository.saveContext(context, request, response);
if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authenticationResult));
@ -145,12 +149,23 @@ public class BearerTokenAuthenticationFilter extends OncePerRequestFilter { @@ -145,12 +149,23 @@ public class BearerTokenAuthenticationFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response);
}
catch (AuthenticationException failed) {
SecurityContextHolder.clearContext();
this.securityContextHolderStrategy.clearContext();
this.logger.trace("Failed to process authentication request", failed);
this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed);
}
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/**
* Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
* authentication success. The default action is not to save the

14
oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/authentication/BearerTokenAuthenticationFilterTests.java

@ -37,6 +37,8 @@ import org.springframework.security.authentication.AuthenticationManagerResolver @@ -37,6 +37,8 @@ import org.springframework.security.authentication.AuthenticationManagerResolver
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.server.resource.BearerTokenError;
import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes;
@ -210,6 +212,18 @@ public class BearerTokenAuthenticationFilterTests { @@ -210,6 +212,18 @@ public class BearerTokenAuthenticationFilterTests {
verify(this.authenticationDetailsSource).buildDetails(this.request);
}
@Test
public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws ServletException, IOException {
given(this.bearerTokenResolver.resolve(this.request)).willReturn("token");
BearerTokenAuthenticationFilter filter = addMocks(
new BearerTokenAuthenticationFilter(this.authenticationManager));
SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
given(strategy.createEmptyContext()).willReturn(new SecurityContextImpl());
filter.setSecurityContextHolderStrategy(strategy);
filter.doFilter(this.request, this.response, this.filterChain);
verify(strategy).setContext(any());
}
@Test
public void setAuthenticationEntryPointWhenNullThenThrowsException() {
BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(this.authenticationManager);

Loading…
Cancel
Save