Browse Source

Fix client_secret_basic authentication failures and return challenge

Closes gh-468
pull/1985/head
Joe Grandja 8 months ago
parent
commit
42c18c856f
  1. 60
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java
  2. 10
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java
  3. 26
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java

60
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2024 the original author or authors.
* Copyright 2020-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,6 +34,7 @@ import org.springframework.security.core.Authentication; @@ -34,6 +34,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
@ -53,6 +54,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter; @@ -53,6 +54,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
@ -90,6 +92,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter @@ -90,6 +92,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
private final BasicAuthenticationEntryPoint basicAuthenticationEntryPoint = new BasicAuthenticationEntryPoint();
private AuthenticationConverter authenticationConverter;
private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;
@ -110,6 +114,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter @@ -110,6 +114,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
Assert.notNull(requestMatcher, "requestMatcher cannot be null");
this.authenticationManager = authenticationManager;
this.requestMatcher = requestMatcher;
this.basicAuthenticationEntryPoint.setRealmName("default");
// @formatter:off
this.authenticationConverter = new DelegatingAuthenticationConverter(
Arrays.asList(
@ -130,8 +135,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter @@ -130,8 +135,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
return;
}
Authentication authenticationRequest = null;
try {
Authentication authenticationRequest = this.authenticationConverter.convert(request);
authenticationRequest = this.authenticationConverter.convert(request);
if (authenticationRequest instanceof AbstractAuthenticationToken) {
((AbstractAuthenticationToken) authenticationRequest)
.setDetails(this.authenticationDetailsSource.buildDetails(request));
@ -148,7 +154,14 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter @@ -148,7 +154,14 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
if (this.logger.isTraceEnabled()) {
this.logger.trace(LogMessage.format("Client authentication failed: %s", ex.getError()), ex);
}
this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
if (authenticationRequest instanceof OAuth2ClientAuthenticationToken clientAuthentication) {
this.authenticationFailureHandler.onAuthenticationFailure(request, response,
new OAuth2ClientAuthenticationException(ex.getError(), ex, clientAuthentication));
}
else {
this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
}
}
}
@ -200,21 +213,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter @@ -200,21 +213,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
}
private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException {
AuthenticationException authenticationException) throws IOException {
SecurityContextHolder.clearContext();
// TODO
// The authorization server MAY return an HTTP 401 (Unauthorized) status code
// to indicate which HTTP authentication schemes are supported.
// If the client attempted to authenticate via the "Authorization" request header
// field,
// the authorization server MUST respond with an HTTP 401 (Unauthorized) status
// code and
// include the "WWW-Authenticate" response header field
// matching the authentication scheme used by the client.
OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
if (authenticationException instanceof OAuth2ClientAuthenticationException clientAuthenticationException) {
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationException
.getClientAuthentication();
if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC
.equals(clientAuthentication.getClientAuthenticationMethod())) {
this.basicAuthenticationEntryPoint.commence(request, response, authenticationException);
return;
}
}
OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
@ -249,4 +262,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter @@ -249,4 +262,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
}
}
private static final class OAuth2ClientAuthenticationException extends OAuth2AuthenticationException {
private final OAuth2ClientAuthenticationToken clientAuthentication;
private OAuth2ClientAuthenticationException(OAuth2Error error, Throwable cause,
OAuth2ClientAuthenticationToken clientAuthentication) {
super(error, cause);
Assert.notNull(clientAuthentication, "clientAuthentication cannot be null");
this.clientAuthentication = clientAuthentication;
}
private OAuth2ClientAuthenticationToken getClientAuthentication() {
return this.clientAuthentication;
}
}
}

10
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2024 the original author or authors.
* Copyright 2020-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -538,7 +538,7 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -538,7 +538,7 @@ public class OAuth2AuthorizationCodeGrantTests {
}
@Test
public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenBadRequest() throws Exception {
public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenUnauthorized() throws Exception {
this.spring.register(AuthorizationServerConfiguration.class).autowire();
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@ -569,7 +569,7 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -569,7 +569,7 @@ public class OAuth2AuthorizationCodeGrantTests {
.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
.andExpect(status().isBadRequest());
.andExpect(status().isUnauthorized());
}
// gh-1011
@ -601,7 +601,7 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -601,7 +601,7 @@ public class OAuth2AuthorizationCodeGrantTests {
}
@Test
public void requestWhenConfidentialClientWithPkceAndMissingCodeChallengeButCodeVerifierProvidedThenBadRequest()
public void requestWhenConfidentialClientWithPkceAndMissingCodeChallengeButCodeVerifierProvidedThenUnauthorized()
throws Exception {
this.spring.register(AuthorizationServerConfiguration.class).autowire();
@ -631,7 +631,7 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -631,7 +631,7 @@ public class OAuth2AuthorizationCodeGrantTests {
.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER)
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
.andExpect(status().isBadRequest());
.andExpect(status().isUnauthorized());
}
@Test

26
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ import org.junit.jupiter.api.BeforeEach; @@ -26,6 +26,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
@ -175,26 +176,25 @@ public class OAuth2ClientAuthenticationFilterTests { @@ -175,26 +176,25 @@ public class OAuth2ClientAuthenticationFilterTests {
// gh-889
@Test
public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenInvalidRequestError()
throws Exception {
public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenReturnChallenge() throws Exception {
// Hex 00 -> null
String clientId = new String(Hex.decode("00"), StandardCharsets.UTF_8);
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
assertWhenInvalidClientIdThenReturnChallenge(clientId);
// Hex 0a61 -> line feed + a
clientId = new String(Hex.decode("0a61"), StandardCharsets.UTF_8);
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
assertWhenInvalidClientIdThenReturnChallenge(clientId);
// Hex 1b -> escape
clientId = new String(Hex.decode("1b"), StandardCharsets.UTF_8);
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
assertWhenInvalidClientIdThenReturnChallenge(clientId);
// Hex 1b61 -> escape + a
clientId = new String(Hex.decode("1b61"), StandardCharsets.UTF_8);
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
assertWhenInvalidClientIdThenReturnChallenge(clientId);
}
private void assertWhenInvalidClientIdThenInvalidRequestError(String clientId) throws Exception {
private void assertWhenInvalidClientIdThenReturnChallenge(String clientId) throws Exception {
given(this.authenticationConverter.convert(any(HttpServletRequest.class)))
.willReturn(new OAuth2ClientAuthenticationToken(clientId, ClientAuthenticationMethod.CLIENT_SECRET_BASIC,
"secret", null));
@ -210,13 +210,12 @@ public class OAuth2ClientAuthenticationFilterTests { @@ -210,13 +210,12 @@ public class OAuth2ClientAuthenticationFilterTests {
verifyNoInteractions(this.authenticationManager);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
OAuth2Error error = readError(response);
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
assertThat(response.getHeader(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Basic realm=\"default\"");
}
@Test
public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError() throws Exception {
public void doFilterWhenRequestMatchesAndBadCredentialsThenReturnChallenge() throws Exception {
given(this.authenticationConverter.convert(any(HttpServletRequest.class)))
.willReturn(new OAuth2ClientAuthenticationToken("clientId", ClientAuthenticationMethod.CLIENT_SECRET_BASIC,
"invalid-secret", null));
@ -235,8 +234,7 @@ public class OAuth2ClientAuthenticationFilterTests { @@ -235,8 +234,7 @@ public class OAuth2ClientAuthenticationFilterTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
OAuth2Error error = readError(response);
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(response.getHeader(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Basic realm=\"default\"");
}
@Test

Loading…
Cancel
Save