Browse Source

OAuth2AuthorizationEndpointFilter is applied after AuthorizationFilter

Closes gh-18251
pull/18256/head
Joe Grandja 3 weeks ago
parent
commit
c53e66a217
  1. 32
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationEndpointConfigurer.java
  2. 21
      config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java
  3. 37
      oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java
  4. 10
      oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationToken.java
  5. 130
      oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java
  6. 12
      oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java
  7. 37
      oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

32
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationEndpointConfigurer.java

@ -16,10 +16,12 @@ @@ -16,10 +16,12 @@
package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import jakarta.servlet.Filter;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.http.HttpMethod;
@ -36,10 +38,12 @@ import org.springframework.security.oauth2.server.authorization.authentication.O @@ -36,10 +38,12 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationValidator;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
import org.springframework.security.web.access.intercept.AuthorizationFilter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
@ -50,6 +54,7 @@ import org.springframework.security.web.servlet.util.matcher.PathPatternRequestM @@ -50,6 +54,7 @@ import org.springframework.security.web.servlet.util.matcher.PathPatternRequestM
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
/**
@ -83,6 +88,8 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C @@ -83,6 +88,8 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidator;
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidatorComposite;
private SessionAuthenticationStrategy sessionAuthenticationStrategy;
/**
@ -248,8 +255,16 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C @@ -248,8 +255,16 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
authenticationProviders.addAll(0, this.authenticationProviders);
}
this.authenticationProvidersConsumer.accept(authenticationProviders);
authenticationProviders.forEach(
(authenticationProvider) -> httpSecurity.authenticationProvider(postProcess(authenticationProvider)));
authenticationProviders.forEach((authenticationProvider) -> {
httpSecurity.authenticationProvider(postProcess(authenticationProvider));
if (authenticationProvider instanceof OAuth2AuthorizationCodeRequestAuthenticationProvider) {
Method method = ReflectionUtils.findMethod(OAuth2AuthorizationCodeRequestAuthenticationProvider.class,
"getAuthenticationValidatorComposite");
ReflectionUtils.makeAccessible(method);
this.authorizationCodeRequestAuthenticationValidatorComposite = (Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext>) ReflectionUtils
.invokeMethod(method, authenticationProvider);
}
});
}
@Override
@ -282,7 +297,18 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C @@ -282,7 +297,18 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
if (this.sessionAuthenticationStrategy != null) {
authorizationEndpointFilter.setSessionAuthenticationStrategy(this.sessionAuthenticationStrategy);
}
httpSecurity.addFilterBefore(postProcess(authorizationEndpointFilter),
httpSecurity.addFilterAfter(postProcess(authorizationEndpointFilter), AuthorizationFilter.class);
// Create and add
// OAuth2AuthorizationEndpointFilter.OAuth2AuthorizationCodeRequestValidatingFilter
Method method = ReflectionUtils.findMethod(OAuth2AuthorizationEndpointFilter.class,
"createAuthorizationCodeRequestValidatingFilter", RegisteredClientRepository.class, Consumer.class);
ReflectionUtils.makeAccessible(method);
RegisteredClientRepository registeredClientRepository = OAuth2ConfigurerUtils
.getRegisteredClientRepository(httpSecurity);
Filter authorizationCodeRequestValidatingFilter = (Filter) ReflectionUtils.invokeMethod(method,
authorizationEndpointFilter, registeredClientRepository,
this.authorizationCodeRequestAuthenticationValidatorComposite);
httpSecurity.addFilterBefore(postProcess(authorizationCodeRequestValidatingFilter),
AbstractPreAuthenticatedProcessingFilter.class);
}

21
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

@ -307,8 +307,8 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -307,8 +307,8 @@ public class OAuth2AuthorizationCodeGrantTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
this.mvc
.perform(
get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient)))
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.queryParams(getAuthorizationRequestParameters(registeredClient)))
.andExpect(status().isBadRequest())
.andReturn();
}
@ -851,21 +851,31 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -851,21 +851,31 @@ public class OAuth2AuthorizationCodeGrantTests {
this.spring.register(AuthorizationServerConfigurationCustomAuthorizationEndpoint.class).autowire();
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
this.registeredClientRepository.save(registeredClient);
TestingAuthenticationToken principal = new TestingAuthenticationToken("principalName", "password");
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE);
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal,
registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes(),
additionalParameters);
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode("code", Instant.now(),
Instant.now().plus(5, ChronoUnit.MINUTES));
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode,
registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED,
registeredClient.getScopes());
given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthenticationResult);
given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication);
given(authorizationRequestAuthenticationProvider
.supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).willReturn(true);
given(authorizationRequestAuthenticationProvider.authenticate(any()))
.willReturn(authorizationCodeRequestAuthenticationResult);
this.mvc
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient))
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.queryParams(getAuthorizationRequestParameters(registeredClient))
.with(user("user")))
.andExpect(status().isOk());
@ -880,8 +890,7 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -880,8 +890,7 @@ public class OAuth2AuthorizationCodeGrantTests {
|| converter instanceof OAuth2AuthorizationCodeRequestAuthenticationConverter
|| converter instanceof OAuth2AuthorizationConsentAuthenticationConverter);
verify(authorizationRequestAuthenticationProvider)
.authenticate(eq(authorizationCodeRequestAuthenticationResult));
verify(authorizationRequestAuthenticationProvider).authenticate(eq(authorizationCodeRequestAuthentication));
@SuppressWarnings("unchecked")
ArgumentCaptor<List<AuthenticationProvider>> authenticationProvidersCaptor = ArgumentCaptor

37
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

@ -190,6 +190,8 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen @@ -190,6 +190,8 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
OAuth2AuthorizationCodeRequestAuthenticationContext.Builder authenticationContextBuilder = OAuth2AuthorizationCodeRequestAuthenticationContext
.with(authorizationCodeRequestAuthentication)
.registeredClient(registeredClient);
if (!authorizationCodeRequestAuthentication.isValidated()) {
OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = authenticationContextBuilder
.build();
@ -205,36 +207,38 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen @@ -205,36 +207,38 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
.accept(authenticationContext);
// prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request)
Set<String> promptValues = Collections.emptySet();
if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)) {
String prompt = (String) authorizationCodeRequestAuthentication.getAdditionalParameters().get("prompt");
if (StringUtils.hasText(prompt)) {
OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR
.accept(authenticationContext);
promptValues = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(prompt, " ")));
}
}
authorizationCodeRequestAuthentication.setValidated(true);
if (this.logger.isTraceEnabled()) {
this.logger.trace("Validated authorization code request parameters");
}
}
// ---------------
// The request is valid - ensure the resource owner is authenticated
// ---------------
Authentication principal = (Authentication) authorizationCodeRequestAuthentication.getPrincipal();
Set<String> promptValues = Collections.emptySet();
if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)) {
String prompt = (String) authorizationCodeRequestAuthentication.getAdditionalParameters().get("prompt");
if (StringUtils.hasText(prompt)) {
promptValues = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(prompt, " ")));
}
}
if (!isPrincipalAuthenticated(principal)) {
if (promptValues.contains(OidcPrompt.NONE)) {
// Return an error instead of displaying the login page (via the
// configured AuthenticationEntryPoint)
throwError("login_required", "prompt", authorizationCodeRequestAuthentication, registeredClient);
}
if (this.logger.isTraceEnabled()) {
this.logger.trace("Did not authenticate authorization code request since principal not authenticated");
else {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "principal", authorizationCodeRequestAuthentication,
registeredClient);
}
// Return the authorization request as-is where isAuthenticated() is false
return authorizationCodeRequestAuthentication;
}
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
@ -400,6 +404,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen @@ -400,6 +404,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
this.authorizationConsentRequired = authorizationConsentRequired;
}
Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> getAuthenticationValidatorComposite() {
return OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_AUTHORIZATION_GRANT_TYPE_VALIDATOR
.andThen(this.authenticationValidator)
.andThen(OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_CODE_CHALLENGE_VALIDATOR)
.andThen(OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR);
}
private static boolean isAuthorizationConsentRequired(
OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext) {
if (!authenticationContext.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent()) {

10
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationToken.java

@ -42,6 +42,8 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken @@ -42,6 +42,8 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken
private final OAuth2AuthorizationCode authorizationCode;
private boolean validated;
/**
* Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationToken} using the
* provided parameters.
@ -89,4 +91,12 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken @@ -89,4 +91,12 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken
return this.authorizationCode;
}
final boolean isValidated() {
return this.validated;
}
final void setValidated(boolean validated) {
this.validated = validated;
}
}

130
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@ -17,11 +17,14 @@ @@ -17,11 +17,14 @@
package org.springframework.security.oauth2.server.authorization.web;
import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.Set;
import java.util.function.Consumer;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
@ -38,14 +41,18 @@ import org.springframework.security.core.AuthenticationException; @@ -38,14 +41,18 @@ import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationContext;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
import org.springframework.security.web.DefaultRedirectStrategy;
@ -64,6 +71,7 @@ import org.springframework.security.web.util.matcher.NegatedRequestMatcher; @@ -64,6 +71,7 @@ import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.UriComponentsBuilder;
@ -180,20 +188,17 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte @@ -180,20 +188,17 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
}
try {
Authentication authentication = this.authenticationConverter.convert(request);
// Get the pre-validated authorization code request (if available),
// which was set by OAuth2AuthorizationCodeRequestValidatingFilter
Authentication authentication = (Authentication) request
.getAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName());
if (authentication == null) {
authentication = this.authenticationConverter.convert(request);
if (authentication instanceof AbstractAuthenticationToken authenticationToken) {
authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request));
}
Authentication authenticationResult = this.authenticationManager.authenticate(authentication);
if (!authenticationResult.isAuthenticated()) {
// If the Principal (Resource Owner) is not authenticated then pass
// through the chain
// with the expectation that the authentication process will commence via
// AuthenticationEntryPoint
filterChain.doFilter(request, response);
return;
}
Authentication authenticationResult = this.authenticationManager.authenticate(authentication);
if (authenticationResult instanceof OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationToken) {
if (this.logger.isTraceEnabled()) {
@ -401,4 +406,109 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte @@ -401,4 +406,109 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
this.redirectStrategy.sendRedirect(request, response, redirectUri);
}
Filter createAuthorizationCodeRequestValidatingFilter(RegisteredClientRepository registeredClientRepository,
Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator) {
return new OAuth2AuthorizationCodeRequestValidatingFilter(registeredClientRepository, authenticationValidator);
}
/**
* A {@code Filter} that is applied before {@code OAuth2AuthorizationEndpointFilter}
* and handles the pre-validation of an OAuth 2.0 Authorization Code Request.
*/
private final class OAuth2AuthorizationCodeRequestValidatingFilter extends OncePerRequestFilter {
private final RegisteredClientRepository registeredClientRepository;
private final Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator;
private final Field setValidatedField;
private OAuth2AuthorizationCodeRequestValidatingFilter(RegisteredClientRepository registeredClientRepository,
Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator) {
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
Assert.notNull(authenticationValidator, "authenticationValidator cannot be null");
this.registeredClientRepository = registeredClientRepository;
this.authenticationValidator = authenticationValidator;
this.setValidatedField = ReflectionUtils.findField(OAuth2AuthorizationCodeRequestAuthenticationToken.class,
"validated");
ReflectionUtils.makeAccessible(this.setValidatedField);
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
if (!OAuth2AuthorizationEndpointFilter.this.authorizationEndpointMatcher.matches(request)) {
filterChain.doFilter(request, response);
return;
}
try {
Authentication authentication = OAuth2AuthorizationEndpointFilter.this.authenticationConverter
.convert(request);
if (!(authentication instanceof OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication)) {
filterChain.doFilter(request, response);
return;
}
String requestUri = (String) authorizationCodeRequestAuthentication.getAdditionalParameters()
.get(OAuth2ParameterNames.REQUEST_URI);
if (StringUtils.hasText(requestUri)) {
filterChain.doFilter(request, response);
return;
}
authorizationCodeRequestAuthentication.setDetails(
OAuth2AuthorizationEndpointFilter.this.authenticationDetailsSource.buildDetails(request));
RegisteredClient registeredClient = this.registeredClientRepository
.findByClientId(authorizationCodeRequestAuthentication.getClientId());
if (registeredClient == null) {
String redirectUri = null; // Prevent redirect
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
authorizationCodeRequestAuthentication.getAuthorizationUri(),
authorizationCodeRequestAuthentication.getClientId(),
(Authentication) authorizationCodeRequestAuthentication.getPrincipal(), redirectUri,
authorizationCodeRequestAuthentication.getState(),
authorizationCodeRequestAuthentication.getScopes(),
authorizationCodeRequestAuthentication.getAdditionalParameters());
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST,
"OAuth 2.0 Parameter: " + OAuth2ParameterNames.CLIENT_ID,
"https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1");
throw new OAuth2AuthorizationCodeRequestAuthenticationException(error,
authorizationCodeRequestAuthenticationResult);
}
OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = OAuth2AuthorizationCodeRequestAuthenticationContext
.with(authorizationCodeRequestAuthentication)
.registeredClient(registeredClient)
.build();
this.authenticationValidator.accept(authenticationContext);
ReflectionUtils.setField(this.setValidatedField, authorizationCodeRequestAuthentication, true);
// Set the validated authorization code request as a request
// attribute
// to be used upstream by OAuth2AuthorizationEndpointFilter
request.setAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName(),
authorizationCodeRequestAuthentication);
filterChain.doFilter(request, response);
}
catch (OAuth2AuthenticationException ex) {
if (this.logger.isTraceEnabled()) {
this.logger.trace(LogMessage.format("Authorization request failed: %s", ex.getError()), ex);
}
OAuth2AuthorizationEndpointFilter.this.authenticationFailureHandler.onAuthenticationFailure(request,
response, ex);
}
finally {
request.removeAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName());
}
}
}
}

12
oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

@ -428,7 +428,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests { @@ -428,7 +428,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
}
@Test
public void authenticateWhenPrincipalNotAuthenticatedThenReturnAuthorizationCodeRequest() {
public void authenticateWhenPrincipalNotAuthenticatedThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
given(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.willReturn(registeredClient);
@ -438,12 +438,10 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests { @@ -438,12 +438,10 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
OAuth2AuthorizationCodeRequestAuthenticationToken authentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, redirectUri, STATE,
registeredClient.getScopes(), createPkceParameters());
OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult = (OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider
.authenticate(authentication);
assertThat(authenticationResult).isSameAs(authentication);
assertThat(authenticationResult.isAuthenticated()).isFalse();
assertThatExceptionOfType(OAuth2AuthorizationCodeRequestAuthenticationException.class)
.isThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.satisfies((ex) -> assertAuthenticationException(ex, OAuth2ErrorCodes.INVALID_REQUEST, "principal",
authentication.getRedirectUri()));
}
@Test

37
oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@ -372,7 +372,11 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -372,7 +372,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
given(authenticationConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication);
this.filter.setAuthenticationConverter(authenticationConverter);
given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthentication);
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, this.authorizationCode,
registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes());
authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthenticationResult);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
MockHttpServletResponse response = new MockHttpServletResponse();
@ -382,7 +386,7 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -382,7 +386,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
verify(authenticationConverter).convert(any());
verify(this.authenticationManager).authenticate(any());
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
verifyNoInteractions(filterChain);
}
@Test
@ -461,9 +465,6 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -461,9 +465,6 @@ public class OAuth2AuthorizationEndpointFilterTests {
@Test
public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
AUTHORIZATION_URI, registeredClient.getClientId(), this.principal,
registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
AuthenticationDetailsSource<HttpServletRequest, WebAuthenticationDetails> authenticationDetailsSource = mock(
@ -472,36 +473,20 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -472,36 +473,20 @@ public class OAuth2AuthorizationEndpointFilterTests {
given(authenticationDetailsSource.buildDetails(request)).willReturn(webAuthenticationDetails);
this.filter.setAuthenticationDetailsSource(authenticationDetailsSource);
given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthentication);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verify(authenticationDetailsSource).buildDetails(any());
verify(this.authenticationManager).authenticate(any());
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception {
this.principal.setAuthenticated(false);
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
AUTHORIZATION_URI, registeredClient.getClientId(), this.principal,
registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null);
authorizationCodeRequestAuthenticationResult.setAuthenticated(false);
AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, this.authorizationCode,
registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes());
authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthenticationResult);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verify(authenticationDetailsSource).buildDetails(any());
verify(this.authenticationManager).authenticate(any());
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
verifyNoInteractions(filterChain);
}
@Test

Loading…
Cancel
Save