diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index dd4bad4b..c280db99 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -31,7 +31,6 @@ import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; @@ -165,7 +164,6 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte request, response, authorizationCodeRequestAuthenticationResult); } catch (OAuth2AuthenticationException ex) { - SecurityContextHolder.clearContext(); this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 35ebe619..77488125 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -41,13 +41,13 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AuthorizationCode; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.oidc.OidcScopes; -import org.springframework.security.oauth2.core.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; @@ -266,6 +266,7 @@ public class OAuth2AuthorizationEndpointFilterTests { assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?error=errorCode&error_description=errorDescription&error_uri=errorUri&state=state"); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.principal); } @Test