From b82a43568c36aec067e1509a345dec1bd238d33d Mon Sep 17 00:00:00 2001 From: wonderfulrosemari Date: Sat, 28 Feb 2026 02:16:09 +0900 Subject: [PATCH] Add authenticationSuccessHandler for OAuth2 auth code callback Signed-off-by: wonderfulrosemari --- .../oauth2/client/OAuth2ClientConfigurer.java | 20 +++++++ .../client/OAuth2ClientConfigurerTests.java | 55 +++++++++++++++++-- .../OAuth2AuthorizationCodeGrantFilter.java | 21 ++++++- ...uth2AuthorizationCodeGrantFilterTests.java | 27 +++++++++ 4 files changed, 118 insertions(+), 5 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java index 870f560825..c5ad2c004e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java @@ -37,6 +37,7 @@ import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequest import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.web.RedirectStrategy; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.util.Assert; @@ -177,6 +178,8 @@ public final class OAuth2ClientConfigurer> private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private AuthenticationSuccessHandler authenticationSuccessHandler; + private AuthorizationCodeGrantConfigurer() { } @@ -231,6 +234,20 @@ public final class OAuth2ClientConfigurer> return this; } + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling a successful + * authorization response. + * @param authenticationSuccessHandler the handler used for handling a successful + * authorization response + * @return the {@link AuthorizationCodeGrantConfigurer} for further configuration + */ + public AuthorizationCodeGrantConfigurer authenticationSuccessHandler( + AuthenticationSuccessHandler authenticationSuccessHandler) { + Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); + this.authenticationSuccessHandler = authenticationSuccessHandler; + return this; + } + private void init(B builder) { OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( getAccessTokenResponseClient()); @@ -288,6 +305,9 @@ public final class OAuth2ClientConfigurer> if (requestCache != null) { authorizationCodeGrantFilter.setRequestCache(requestCache); } + if (this.authenticationSuccessHandler != null) { + authorizationCodeGrantFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler); + } return authorizationCodeGrantFilter; } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java index 2a2abf6af0..01194b3a4f 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java @@ -61,6 +61,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -106,6 +107,8 @@ public class OAuth2ClientConfigurerTests { private static RequestCache requestCache; + private static AuthenticationSuccessHandler authenticationSuccessHandler; + public final SpringTestContext spring = new SpringTestContext(this); @Autowired @@ -146,6 +149,7 @@ public class OAuth2ClientConfigurerTests { given(accessTokenResponseClient.getTokenResponse(any(OAuth2AuthorizationCodeGrantRequest.class))) .willReturn(accessTokenResponse); requestCache = mock(RequestCache.class); + authenticationSuccessHandler = null; } @Test @@ -345,6 +349,45 @@ public class OAuth2ClientConfigurerTests { verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository); } + @Test + public void configureWhenCustomAuthenticationSuccessHandlerSetThenAuthenticationSuccessHandlerUsed() + throws Exception { + authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class); + this.spring.register(OAuth2ClientConfig.class).autowire(); + Map attributes = new HashMap<>(); + attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId()); + // @formatter:off + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri()) + .clientId(this.registration1.getClientId()) + .redirectUri("http://localhost/client-1") + .state("state") + .attributes(attributes) + .build(); + // @formatter:on + AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); + MockHttpSession session = (MockHttpSession) request.getSession(); + String principalName = "user1"; + TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); + // @formatter:off + MockHttpServletRequestBuilder clientRequest = get("/client-1") + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, "state") + .with(authentication(authentication)) + .session(session); + this.mockMvc.perform(clientRequest) + .andExpect(status().isOk()); + // @formatter:on + verify(authenticationSuccessHandler).onAuthenticationSuccess(any(HttpServletRequest.class), + any(HttpServletResponse.class), any()); + OAuth2AuthorizedClient authorizedClient = authorizedClientRepository + .loadAuthorizedClient(this.registration1.getRegistrationId(), authentication, request); + assertThat(authorizedClient).isNotNull(); + } + @EnableWebSecurity @Configuration @EnableWebMvc @@ -359,10 +402,14 @@ public class OAuth2ClientConfigurerTests { .requestCache((cache) -> cache .requestCache(requestCache)) .oauth2Client((client) -> client - .authorizationCodeGrant((code) -> code - .authorizationRequestResolver(authorizationRequestResolver) - .authorizationRedirectStrategy(authorizationRedirectStrategy) - .accessTokenResponseClient(accessTokenResponseClient))); + .authorizationCodeGrant((code) -> { + code.authorizationRequestResolver(authorizationRequestResolver) + .authorizationRedirectStrategy(authorizationRedirectStrategy) + .accessTokenResponseClient(accessTokenResponseClient); + if (authenticationSuccessHandler != null) { + code.authenticationSuccessHandler(authenticationSuccessHandler); + } + })); return http.build(); // @formatter:on } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java index ec6b9ae2a5..94901729eb 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java @@ -46,6 +46,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.RedirectStrategy; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.RequestCache; @@ -121,6 +122,8 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { private RequestCache requestCache = new HttpSessionRequestCache(); + private AuthenticationSuccessHandler authenticationSuccessHandler; + /** * Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided * parameters. @@ -162,6 +165,18 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { this.requestCache = requestCache; } + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling a successful + * authorization response. + * @param authenticationSuccessHandler the handler used for handling a successful + * authorization response + * @since 7.1 + */ + public final void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) { + Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); + this.authenticationSuccessHandler = authenticationSuccessHandler; + } + /** * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. @@ -217,7 +232,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { } private void processAuthorizationResponse(HttpServletRequest request, HttpServletResponse response) - throws IOException { + throws IOException, ServletException { OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository .removeAuthorizationRequest(request, response); String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID); @@ -254,6 +269,10 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { authenticationResult.getRefreshToken()); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, response); + if (this.authenticationSuccessHandler != null) { + this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult); + return; + } String redirectUrl = authorizationRequest.getRedirectUri(); SavedRequest savedRequest = this.requestCache.getRequest(request, response); if (savedRequest != null) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java index 7acb136b3a..bd985d5f7a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java @@ -57,6 +57,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.SavedRequest; @@ -152,6 +153,11 @@ public class OAuth2AuthorizationCodeGrantFilterTests { assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null)); } + @Test + public void setAuthenticationSuccessHandlerWhenAuthenticationSuccessHandlerIsNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null)); + } + @Test public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception { String requestUri = "/path"; @@ -308,6 +314,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests { assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1"); } + @Test + public void doFilterWhenAuthorizationSucceedsAndAuthenticationSuccessHandlerConfiguredThenAuthenticationSuccessHandlerUsed() + throws Exception { + MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); + this.setUpAuthenticationResult(this.registration1); + AuthenticationSuccessHandler authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class); + this.filter.setAuthenticationSuccessHandler(authenticationSuccessHandler); + this.filter.doFilter(authorizationResponse, response, filterChain); + verify(authenticationSuccessHandler).onAuthenticationSuccess(any(HttpServletRequest.class), + any(HttpServletResponse.class), any(Authentication.class)); + verifyNoInteractions(filterChain); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1); + assertThat(authorizedClient).isNotNull(); + assertThat(response.getRedirectedUrl()).isNull(); + } + @Test public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws Exception { MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");