Browse Source

Merge b82a43568c into ea2f2302da

pull/18814/merge
JinHyeokCho 1 week ago committed by GitHub
parent
commit
5e4fe02f6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 20
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java
  2. 55
      config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java
  3. 21
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java
  4. 27
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java

20
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java

@ -37,6 +37,7 @@ import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequest @@ -37,6 +37,7 @@ import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequest
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.util.Assert;
@ -177,6 +178,8 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> @@ -177,6 +178,8 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>>
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
private AuthenticationSuccessHandler authenticationSuccessHandler;
private AuthorizationCodeGrantConfigurer() {
}
@ -231,6 +234,20 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> @@ -231,6 +234,20 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>>
return this;
}
/**
* Sets the {@link AuthenticationSuccessHandler} used for handling a successful
* authorization response.
* @param authenticationSuccessHandler the handler used for handling a successful
* authorization response
* @return the {@link AuthorizationCodeGrantConfigurer} for further configuration
*/
public AuthorizationCodeGrantConfigurer authenticationSuccessHandler(
AuthenticationSuccessHandler authenticationSuccessHandler) {
Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
this.authenticationSuccessHandler = authenticationSuccessHandler;
return this;
}
private void init(B builder) {
OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
getAccessTokenResponseClient());
@ -288,6 +305,9 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> @@ -288,6 +305,9 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>>
if (requestCache != null) {
authorizationCodeGrantFilter.setRequestCache(requestCache);
}
if (this.authenticationSuccessHandler != null) {
authorizationCodeGrantFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler);
}
return authorizationCodeGrantFilter;
}

55
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java

@ -61,6 +61,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -61,6 +61,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
@ -106,6 +107,8 @@ public class OAuth2ClientConfigurerTests { @@ -106,6 +107,8 @@ public class OAuth2ClientConfigurerTests {
private static RequestCache requestCache;
private static AuthenticationSuccessHandler authenticationSuccessHandler;
public final SpringTestContext spring = new SpringTestContext(this);
@Autowired
@ -146,6 +149,7 @@ public class OAuth2ClientConfigurerTests { @@ -146,6 +149,7 @@ public class OAuth2ClientConfigurerTests {
given(accessTokenResponseClient.getTokenResponse(any(OAuth2AuthorizationCodeGrantRequest.class)))
.willReturn(accessTokenResponse);
requestCache = mock(RequestCache.class);
authenticationSuccessHandler = null;
}
@Test
@ -345,6 +349,45 @@ public class OAuth2ClientConfigurerTests { @@ -345,6 +349,45 @@ public class OAuth2ClientConfigurerTests {
verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository);
}
@Test
public void configureWhenCustomAuthenticationSuccessHandlerSetThenAuthenticationSuccessHandlerUsed()
throws Exception {
authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
this.spring.register(OAuth2ClientConfig.class).autowire();
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId());
// @formatter:off
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri())
.clientId(this.registration1.getClientId())
.redirectUri("http://localhost/client-1")
.state("state")
.attributes(attributes)
.build();
// @formatter:on
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
MockHttpServletRequest request = new MockHttpServletRequest("GET", "");
MockHttpServletResponse response = new MockHttpServletResponse();
authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
MockHttpSession session = (MockHttpSession) request.getSession();
String principalName = "user1";
TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
// @formatter:off
MockHttpServletRequestBuilder clientRequest = get("/client-1")
.param(OAuth2ParameterNames.CODE, "code")
.param(OAuth2ParameterNames.STATE, "state")
.with(authentication(authentication))
.session(session);
this.mockMvc.perform(clientRequest)
.andExpect(status().isOk());
// @formatter:on
verify(authenticationSuccessHandler).onAuthenticationSuccess(any(HttpServletRequest.class),
any(HttpServletResponse.class), any());
OAuth2AuthorizedClient authorizedClient = authorizedClientRepository
.loadAuthorizedClient(this.registration1.getRegistrationId(), authentication, request);
assertThat(authorizedClient).isNotNull();
}
@EnableWebSecurity
@Configuration
@EnableWebMvc
@ -359,10 +402,14 @@ public class OAuth2ClientConfigurerTests { @@ -359,10 +402,14 @@ public class OAuth2ClientConfigurerTests {
.requestCache((cache) -> cache
.requestCache(requestCache))
.oauth2Client((client) -> client
.authorizationCodeGrant((code) -> code
.authorizationRequestResolver(authorizationRequestResolver)
.authorizationRedirectStrategy(authorizationRedirectStrategy)
.accessTokenResponseClient(accessTokenResponseClient)));
.authorizationCodeGrant((code) -> {
code.authorizationRequestResolver(authorizationRequestResolver)
.authorizationRedirectStrategy(authorizationRedirectStrategy)
.accessTokenResponseClient(accessTokenResponseClient);
if (authenticationSuccessHandler != null) {
code.authenticationSuccessHandler(authenticationSuccessHandler);
}
}));
return http.build();
// @formatter:on
}

21
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java

@ -46,6 +46,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp @@ -46,6 +46,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
@ -121,6 +122,8 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { @@ -121,6 +122,8 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
private RequestCache requestCache = new HttpSessionRequestCache();
private AuthenticationSuccessHandler authenticationSuccessHandler;
/**
* Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided
* parameters.
@ -162,6 +165,18 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { @@ -162,6 +165,18 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
this.requestCache = requestCache;
}
/**
* Sets the {@link AuthenticationSuccessHandler} used for handling a successful
* authorization response.
* @param authenticationSuccessHandler the handler used for handling a successful
* authorization response
* @since 7.1
*/
public final void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
this.authenticationSuccessHandler = authenticationSuccessHandler;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
@ -217,7 +232,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { @@ -217,7 +232,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
}
private void processAuthorizationResponse(HttpServletRequest request, HttpServletResponse response)
throws IOException {
throws IOException, ServletException {
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository
.removeAuthorizationRequest(request, response);
String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID);
@ -254,6 +269,10 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { @@ -254,6 +269,10 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
authenticationResult.getRefreshToken());
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request,
response);
if (this.authenticationSuccessHandler != null) {
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
return;
}
String redirectUrl = authorizationRequest.getRedirectUri();
SavedRequest savedRequest = this.requestCache.getRequest(request, response);
if (savedRequest != null) {

27
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java

@ -57,6 +57,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ @@ -57,6 +57,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.SavedRequest;
@ -152,6 +153,11 @@ public class OAuth2AuthorizationCodeGrantFilterTests { @@ -152,6 +153,11 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null));
}
@Test
public void setAuthenticationSuccessHandlerWhenAuthenticationSuccessHandlerIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null));
}
@Test
public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception {
String requestUri = "/path";
@ -308,6 +314,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests { @@ -308,6 +314,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
}
@Test
public void doFilterWhenAuthorizationSucceedsAndAuthenticationSuccessHandlerConfiguredThenAuthenticationSuccessHandlerUsed()
throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
AuthenticationSuccessHandler authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
this.filter.setAuthenticationSuccessHandler(authenticationSuccessHandler);
this.filter.doFilter(authorizationResponse, response, filterChain);
verify(authenticationSuccessHandler).onAuthenticationSuccess(any(HttpServletRequest.class),
any(HttpServletResponse.class), any(Authentication.class));
verifyNoInteractions(filterChain);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
assertThat(authorizedClient).isNotNull();
assertThat(response.getRedirectedUrl()).isNull();
}
@Test
public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");

Loading…
Cancel
Save