From c53e66a217fc41668b2f08e7cc9c87ac5ffa4d70 Mon Sep 17 00:00:00 2001 From: Joe Grandja <10884212+jgrandja@users.noreply.github.com> Date: Fri, 28 Nov 2025 06:06:35 -0500 Subject: [PATCH] OAuth2AuthorizationEndpointFilter is applied after AuthorizationFilter Closes gh-18251 --- ...OAuth2AuthorizationEndpointConfigurer.java | 32 ++++- .../OAuth2AuthorizationCodeGrantTests.java | 21 ++- ...tionCodeRequestAuthenticationProvider.java | 67 +++++---- ...izationCodeRequestAuthenticationToken.java | 10 ++ .../OAuth2AuthorizationEndpointFilter.java | 134 ++++++++++++++++-- ...odeRequestAuthenticationProviderTests.java | 12 +- ...Auth2AuthorizationEndpointFilterTests.java | 37 ++--- 7 files changed, 231 insertions(+), 82 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationEndpointConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationEndpointConfigurer.java index 678347f761..495358e86e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationEndpointConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationEndpointConfigurer.java @@ -16,10 +16,12 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; +import jakarta.servlet.Filter; import jakarta.servlet.http.HttpServletRequest; import org.springframework.http.HttpMethod; @@ -36,10 +38,12 @@ import org.springframework.security.oauth2.server.authorization.authentication.O import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationValidator; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter; +import org.springframework.security.web.access.intercept.AuthorizationFilter; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; @@ -50,6 +54,7 @@ import org.springframework.security.web.servlet.util.matcher.PathPatternRequestM import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; /** @@ -83,6 +88,8 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C private Consumer authorizationCodeRequestAuthenticationValidator; + private Consumer authorizationCodeRequestAuthenticationValidatorComposite; + private SessionAuthenticationStrategy sessionAuthenticationStrategy; /** @@ -248,8 +255,16 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C authenticationProviders.addAll(0, this.authenticationProviders); } this.authenticationProvidersConsumer.accept(authenticationProviders); - authenticationProviders.forEach( - (authenticationProvider) -> httpSecurity.authenticationProvider(postProcess(authenticationProvider))); + authenticationProviders.forEach((authenticationProvider) -> { + httpSecurity.authenticationProvider(postProcess(authenticationProvider)); + if (authenticationProvider instanceof OAuth2AuthorizationCodeRequestAuthenticationProvider) { + Method method = ReflectionUtils.findMethod(OAuth2AuthorizationCodeRequestAuthenticationProvider.class, + "getAuthenticationValidatorComposite"); + ReflectionUtils.makeAccessible(method); + this.authorizationCodeRequestAuthenticationValidatorComposite = (Consumer) ReflectionUtils + .invokeMethod(method, authenticationProvider); + } + }); } @Override @@ -282,7 +297,18 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C if (this.sessionAuthenticationStrategy != null) { authorizationEndpointFilter.setSessionAuthenticationStrategy(this.sessionAuthenticationStrategy); } - httpSecurity.addFilterBefore(postProcess(authorizationEndpointFilter), + httpSecurity.addFilterAfter(postProcess(authorizationEndpointFilter), AuthorizationFilter.class); + // Create and add + // OAuth2AuthorizationEndpointFilter.OAuth2AuthorizationCodeRequestValidatingFilter + Method method = ReflectionUtils.findMethod(OAuth2AuthorizationEndpointFilter.class, + "createAuthorizationCodeRequestValidatingFilter", RegisteredClientRepository.class, Consumer.class); + ReflectionUtils.makeAccessible(method); + RegisteredClientRepository registeredClientRepository = OAuth2ConfigurerUtils + .getRegisteredClientRepository(httpSecurity); + Filter authorizationCodeRequestValidatingFilter = (Filter) ReflectionUtils.invokeMethod(method, + authorizationEndpointFilter, registeredClientRepository, + this.authorizationCodeRequestAuthenticationValidatorComposite); + httpSecurity.addFilterBefore(postProcess(authorizationCodeRequestValidatingFilter), AbstractPreAuthenticatedProcessingFilter.class); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index 7b485487d8..af8b3123a4 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -307,8 +307,8 @@ public class OAuth2AuthorizationCodeGrantTests { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); this.mvc - .perform( - get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient))) + .perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) + .queryParams(getAuthorizationRequestParameters(registeredClient))) .andExpect(status().isBadRequest()) .andReturn(); } @@ -851,21 +851,31 @@ public class OAuth2AuthorizationCodeGrantTests { this.spring.register(AuthorizationServerConfigurationCustomAuthorizationEndpoint.class).autowire(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + this.registeredClientRepository.save(registeredClient); + TestingAuthenticationToken principal = new TestingAuthenticationToken("principalName", "password"); + Map additionalParameters = new HashMap<>(); + additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE); + additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); + OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken( + "https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, + registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes(), + additionalParameters); OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode("code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES)); OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken( "https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode, registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes()); - given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthenticationResult); + given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication); given(authorizationRequestAuthenticationProvider .supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).willReturn(true); given(authorizationRequestAuthenticationProvider.authenticate(any())) .willReturn(authorizationCodeRequestAuthenticationResult); this.mvc - .perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient)) + .perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) + .queryParams(getAuthorizationRequestParameters(registeredClient)) .with(user("user"))) .andExpect(status().isOk()); @@ -880,8 +890,7 @@ public class OAuth2AuthorizationCodeGrantTests { || converter instanceof OAuth2AuthorizationCodeRequestAuthenticationConverter || converter instanceof OAuth2AuthorizationConsentAuthenticationConverter); - verify(authorizationRequestAuthenticationProvider) - .authenticate(eq(authorizationCodeRequestAuthenticationResult)); + verify(authorizationRequestAuthenticationProvider).authenticate(eq(authorizationCodeRequestAuthentication)); @SuppressWarnings("unchecked") ArgumentCaptor> authenticationProvidersCaptor = ArgumentCaptor diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java index 05368c6abb..6c3fd90eff 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java @@ -190,33 +190,31 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen OAuth2AuthorizationCodeRequestAuthenticationContext.Builder authenticationContextBuilder = OAuth2AuthorizationCodeRequestAuthenticationContext .with(authorizationCodeRequestAuthentication) .registeredClient(registeredClient); - OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = authenticationContextBuilder - .build(); - // grant_type - OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_AUTHORIZATION_GRANT_TYPE_VALIDATOR - .accept(authenticationContext); + if (!authorizationCodeRequestAuthentication.isValidated()) { + OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = authenticationContextBuilder + .build(); - // redirect_uri and scope - this.authenticationValidator.accept(authenticationContext); + // grant_type + OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_AUTHORIZATION_GRANT_TYPE_VALIDATOR + .accept(authenticationContext); - // code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE) - OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_CODE_CHALLENGE_VALIDATOR - .accept(authenticationContext); + // redirect_uri and scope + this.authenticationValidator.accept(authenticationContext); - // prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request) - Set promptValues = Collections.emptySet(); - if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)) { - String prompt = (String) authorizationCodeRequestAuthentication.getAdditionalParameters().get("prompt"); - if (StringUtils.hasText(prompt)) { - OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR - .accept(authenticationContext); - promptValues = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(prompt, " "))); - } - } + // code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE) + OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_CODE_CHALLENGE_VALIDATOR + .accept(authenticationContext); - if (this.logger.isTraceEnabled()) { - this.logger.trace("Validated authorization code request parameters"); + // prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request) + OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR + .accept(authenticationContext); + + authorizationCodeRequestAuthentication.setValidated(true); + + if (this.logger.isTraceEnabled()) { + this.logger.trace("Validated authorization code request parameters"); + } } // --------------- @@ -224,17 +222,23 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen // --------------- Authentication principal = (Authentication) authorizationCodeRequestAuthentication.getPrincipal(); + + Set promptValues = Collections.emptySet(); + if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)) { + String prompt = (String) authorizationCodeRequestAuthentication.getAdditionalParameters().get("prompt"); + if (StringUtils.hasText(prompt)) { + promptValues = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(prompt, " "))); + } + } + if (!isPrincipalAuthenticated(principal)) { if (promptValues.contains(OidcPrompt.NONE)) { - // Return an error instead of displaying the login page (via the - // configured AuthenticationEntryPoint) throwError("login_required", "prompt", authorizationCodeRequestAuthentication, registeredClient); } - if (this.logger.isTraceEnabled()) { - this.logger.trace("Did not authenticate authorization code request since principal not authenticated"); + else { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, "principal", authorizationCodeRequestAuthentication, + registeredClient); } - // Return the authorization request as-is where isAuthenticated() is false - return authorizationCodeRequestAuthentication; } OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() @@ -400,6 +404,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen this.authorizationConsentRequired = authorizationConsentRequired; } + Consumer getAuthenticationValidatorComposite() { + return OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_AUTHORIZATION_GRANT_TYPE_VALIDATOR + .andThen(this.authenticationValidator) + .andThen(OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_CODE_CHALLENGE_VALIDATOR) + .andThen(OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR); + } + private static boolean isAuthorizationConsentRequired( OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext) { if (!authenticationContext.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent()) { diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationToken.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationToken.java index 95e0489a56..bf8d0445d7 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationToken.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationToken.java @@ -42,6 +42,8 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken private final OAuth2AuthorizationCode authorizationCode; + private boolean validated; + /** * Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationToken} using the * provided parameters. @@ -89,4 +91,12 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken return this.authorizationCode; } + final boolean isValidated() { + return this.validated; + } + + final void setValidated(boolean validated) { + this.validated = validated; + } + } diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index dddc84336c..99de407539 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -17,11 +17,14 @@ package org.springframework.security.oauth2.server.authorization.web; import java.io.IOException; +import java.lang.reflect.Field; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collections; import java.util.Set; +import java.util.function.Consumer; +import jakarta.servlet.Filter; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; @@ -38,14 +41,18 @@ import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationContext; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter; import org.springframework.security.web.DefaultRedirectStrategy; @@ -64,6 +71,7 @@ import org.springframework.security.web.util.matcher.NegatedRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.util.UriComponentsBuilder; @@ -180,21 +188,18 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte } try { - Authentication authentication = this.authenticationConverter.convert(request); - if (authentication instanceof AbstractAuthenticationToken authenticationToken) { - authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request)); + // Get the pre-validated authorization code request (if available), + // which was set by OAuth2AuthorizationCodeRequestValidatingFilter + Authentication authentication = (Authentication) request + .getAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName()); + if (authentication == null) { + authentication = this.authenticationConverter.convert(request); + if (authentication instanceof AbstractAuthenticationToken authenticationToken) { + authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request)); + } } Authentication authenticationResult = this.authenticationManager.authenticate(authentication); - if (!authenticationResult.isAuthenticated()) { - // If the Principal (Resource Owner) is not authenticated then pass - // through the chain - // with the expectation that the authentication process will commence via - // AuthenticationEntryPoint - filterChain.doFilter(request, response); - return; - } - if (authenticationResult instanceof OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationToken) { if (this.logger.isTraceEnabled()) { this.logger.trace("Authorization consent is required"); @@ -401,4 +406,109 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte this.redirectStrategy.sendRedirect(request, response, redirectUri); } + Filter createAuthorizationCodeRequestValidatingFilter(RegisteredClientRepository registeredClientRepository, + Consumer authenticationValidator) { + return new OAuth2AuthorizationCodeRequestValidatingFilter(registeredClientRepository, authenticationValidator); + } + + /** + * A {@code Filter} that is applied before {@code OAuth2AuthorizationEndpointFilter} + * and handles the pre-validation of an OAuth 2.0 Authorization Code Request. + */ + private final class OAuth2AuthorizationCodeRequestValidatingFilter extends OncePerRequestFilter { + + private final RegisteredClientRepository registeredClientRepository; + + private final Consumer authenticationValidator; + + private final Field setValidatedField; + + private OAuth2AuthorizationCodeRequestValidatingFilter(RegisteredClientRepository registeredClientRepository, + Consumer authenticationValidator) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(authenticationValidator, "authenticationValidator cannot be null"); + this.registeredClientRepository = registeredClientRepository; + this.authenticationValidator = authenticationValidator; + this.setValidatedField = ReflectionUtils.findField(OAuth2AuthorizationCodeRequestAuthenticationToken.class, + "validated"); + ReflectionUtils.makeAccessible(this.setValidatedField); + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + if (!OAuth2AuthorizationEndpointFilter.this.authorizationEndpointMatcher.matches(request)) { + filterChain.doFilter(request, response); + return; + } + + try { + Authentication authentication = OAuth2AuthorizationEndpointFilter.this.authenticationConverter + .convert(request); + if (!(authentication instanceof OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication)) { + filterChain.doFilter(request, response); + return; + } + + String requestUri = (String) authorizationCodeRequestAuthentication.getAdditionalParameters() + .get(OAuth2ParameterNames.REQUEST_URI); + if (StringUtils.hasText(requestUri)) { + filterChain.doFilter(request, response); + return; + } + + authorizationCodeRequestAuthentication.setDetails( + OAuth2AuthorizationEndpointFilter.this.authenticationDetailsSource.buildDetails(request)); + + RegisteredClient registeredClient = this.registeredClientRepository + .findByClientId(authorizationCodeRequestAuthentication.getClientId()); + if (registeredClient == null) { + String redirectUri = null; // Prevent redirect + OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken( + authorizationCodeRequestAuthentication.getAuthorizationUri(), + authorizationCodeRequestAuthentication.getClientId(), + (Authentication) authorizationCodeRequestAuthentication.getPrincipal(), redirectUri, + authorizationCodeRequestAuthentication.getState(), + authorizationCodeRequestAuthentication.getScopes(), + authorizationCodeRequestAuthentication.getAdditionalParameters()); + + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, + "OAuth 2.0 Parameter: " + OAuth2ParameterNames.CLIENT_ID, + "https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1"); + throw new OAuth2AuthorizationCodeRequestAuthenticationException(error, + authorizationCodeRequestAuthenticationResult); + } + + OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = OAuth2AuthorizationCodeRequestAuthenticationContext + .with(authorizationCodeRequestAuthentication) + .registeredClient(registeredClient) + .build(); + + this.authenticationValidator.accept(authenticationContext); + + ReflectionUtils.setField(this.setValidatedField, authorizationCodeRequestAuthentication, true); + + // Set the validated authorization code request as a request + // attribute + // to be used upstream by OAuth2AuthorizationEndpointFilter + request.setAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName(), + authorizationCodeRequestAuthentication); + + filterChain.doFilter(request, response); + } + catch (OAuth2AuthenticationException ex) { + if (this.logger.isTraceEnabled()) { + this.logger.trace(LogMessage.format("Authorization request failed: %s", ex.getError()), ex); + } + OAuth2AuthorizationEndpointFilter.this.authenticationFailureHandler.onAuthenticationFailure(request, + response, ex); + } + finally { + request.removeAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName()); + } + } + + } + } diff --git a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java index c1ac268067..15c536ad6c 100644 --- a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java +++ b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java @@ -428,7 +428,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests { } @Test - public void authenticateWhenPrincipalNotAuthenticatedThenReturnAuthorizationCodeRequest() { + public void authenticateWhenPrincipalNotAuthenticatedThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); given(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .willReturn(registeredClient); @@ -438,12 +438,10 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests { OAuth2AuthorizationCodeRequestAuthenticationToken authentication = new OAuth2AuthorizationCodeRequestAuthenticationToken( AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, redirectUri, STATE, registeredClient.getScopes(), createPkceParameters()); - - OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult = (OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider - .authenticate(authentication); - - assertThat(authenticationResult).isSameAs(authentication); - assertThat(authenticationResult.isAuthenticated()).isFalse(); + assertThatExceptionOfType(OAuth2AuthorizationCodeRequestAuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .satisfies((ex) -> assertAuthenticationException(ex, OAuth2ErrorCodes.INVALID_REQUEST, "principal", + authentication.getRedirectUri())); } @Test diff --git a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 063ecaa759..e37cf233ae 100644 --- a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -372,7 +372,11 @@ public class OAuth2AuthorizationEndpointFilterTests { given(authenticationConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication); this.filter.setAuthenticationConverter(authenticationConverter); - given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthentication); + OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken( + AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, this.authorizationCode, + registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes()); + authorizationCodeRequestAuthenticationResult.setAuthenticated(true); + given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthenticationResult); MockHttpServletRequest request = createAuthorizationRequest(registeredClient); MockHttpServletResponse response = new MockHttpServletResponse(); @@ -382,7 +386,7 @@ public class OAuth2AuthorizationEndpointFilterTests { verify(authenticationConverter).convert(any()); verify(this.authenticationManager).authenticate(any()); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + verifyNoInteractions(filterChain); } @Test @@ -461,9 +465,6 @@ public class OAuth2AuthorizationEndpointFilterTests { @Test public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken( - AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, - registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null); MockHttpServletRequest request = createAuthorizationRequest(registeredClient); AuthenticationDetailsSource authenticationDetailsSource = mock( @@ -472,36 +473,20 @@ public class OAuth2AuthorizationEndpointFilterTests { given(authenticationDetailsSource.buildDetails(request)).willReturn(webAuthenticationDetails); this.filter.setAuthenticationDetailsSource(authenticationDetailsSource); - given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthentication); - - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verify(authenticationDetailsSource).buildDetails(any()); - verify(this.authenticationManager).authenticate(any()); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - } - - @Test - public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception { - this.principal.setAuthenticated(false); - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken( - AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, - registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null); - authorizationCodeRequestAuthenticationResult.setAuthenticated(false); + AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, this.authorizationCode, + registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes()); + authorizationCodeRequestAuthenticationResult.setAuthenticated(true); given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthenticationResult); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); + verify(authenticationDetailsSource).buildDetails(any()); verify(this.authenticationManager).authenticate(any()); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + verifyNoInteractions(filterChain); } @Test