From b455268fa1b801f244f2a94dac3d416d5e481397 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 3 Nov 2021 11:54:59 -0400 Subject: [PATCH] Polish OAuth2ClientAuthenticationProviderTests --- .../OAuth2ClientAuthenticationProvider.java | 33 +++++++----- .../web/OAuth2ClientAuthenticationFilter.java | 4 +- ...uth2ClientAuthenticationProviderTests.java | 54 ++++++++++++------- 3 files changed, 59 insertions(+), 32 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java index d7306a49..b99a3332 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java @@ -28,6 +28,7 @@ import org.springframework.security.crypto.factory.PasswordEncoderFactories; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2TokenType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -54,6 +55,7 @@ import org.springframework.util.StringUtils; * @see PasswordEncoder */ public final class OAuth2ClientAuthenticationProvider implements AuthenticationProvider { + private static final String CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-01#section-3.2.1"; private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); private final RegisteredClientRepository registeredClientRepository; private final OAuth2AuthorizationService authorizationService; @@ -95,28 +97,28 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP String clientId = clientAuthentication.getPrincipal().toString(); RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); if (registeredClient == null) { - throwInvalidClient(); + throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); } if (!registeredClient.getClientAuthenticationMethods().contains( clientAuthentication.getClientAuthenticationMethod())) { - throwInvalidClient(); + throwInvalidClient("authentication_method"); } - boolean authenticatedCredentials = false; + boolean credentialsAuthenticated = false; if (clientAuthentication.getCredentials() != null) { String clientSecret = clientAuthentication.getCredentials().toString(); if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) { - throwInvalidClient(); + throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET); } - authenticatedCredentials = true; + credentialsAuthenticated = true; } boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient); - authenticatedCredentials = authenticatedCredentials || pkceAuthenticated; - if (!authenticatedCredentials) { - throwInvalidClient(); + credentialsAuthenticated = credentialsAuthenticated || pkceAuthenticated; + if (!credentialsAuthenticated) { + throwInvalidClient("credentials"); } return new OAuth2ClientAuthenticationToken(registeredClient, @@ -140,7 +142,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP (String) parameters.get(OAuth2ParameterNames.CODE), AUTHORIZATION_CODE_TOKEN_TYPE); if (authorization == null) { - throwInvalidClient(); + throwInvalidClient(OAuth2ParameterNames.CODE); } OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( @@ -150,7 +152,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP .get(PkceParameterNames.CODE_CHALLENGE); if (!StringUtils.hasText(codeChallenge)) { if (registeredClient.getClientSettings().isRequireProofKey()) { - throwInvalidClient(); + throwInvalidClient(PkceParameterNames.CODE_CHALLENGE); } else { return false; } @@ -160,7 +162,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP .get(PkceParameterNames.CODE_CHALLENGE_METHOD); String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER); if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) { - throwInvalidClient(); + throwInvalidClient(PkceParameterNames.CODE_VERIFIER); } return true; @@ -191,7 +193,12 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP throw new OAuth2AuthenticationException(OAuth2ErrorCodes.SERVER_ERROR); } - private static void throwInvalidClient() { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT); + private static void throwInvalidClient(String parameterName) { + OAuth2Error error = new OAuth2Error( + OAuth2ErrorCodes.INVALID_CLIENT, + "Client authentication failed: " + parameterName, + CLIENT_AUTHENTICATION_ERROR_URI); + throw new OAuth2AuthenticationException(error); } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java index b20ae7e7..3a375b52 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java @@ -178,7 +178,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter } else { httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); } - this.errorHttpResponseConverter.write(error, null, httpResponse); + // We don't want to reveal too much information to the caller so just return the error code + OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode()); + this.errorHttpResponseConverter.write(errorResponse, null, httpResponse); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java index c04191c4..4fc75026 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java @@ -128,8 +128,10 @@ public class OAuth2ClientAuthenticationProviderTests { assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_ID); + }); } @Test @@ -143,8 +145,10 @@ public class OAuth2ClientAuthenticationProviderTests { assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_SECRET); + }); verify(this.passwordEncoder).matches(any(), any()); } @@ -159,8 +163,10 @@ public class OAuth2ClientAuthenticationProviderTests { assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains("credentials"); + }); } @Test @@ -222,8 +228,10 @@ public class OAuth2ClientAuthenticationProviderTests { assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains(OAuth2ParameterNames.CODE); + }); } @Test @@ -246,8 +254,10 @@ public class OAuth2ClientAuthenticationProviderTests { assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); + }); } @Test @@ -270,8 +280,10 @@ public class OAuth2ClientAuthenticationProviderTests { assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); + }); } @Test @@ -294,8 +306,10 @@ public class OAuth2ClientAuthenticationProviderTests { assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); + }); } @Test @@ -318,8 +332,10 @@ public class OAuth2ClientAuthenticationProviderTests { assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); + }); } @Test @@ -437,8 +453,10 @@ public class OAuth2ClientAuthenticationProviderTests { assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains("authentication_method"); + }); } private static Map createAuthorizationCodeTokenParameters() {