diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java index 573eb83e..d7306a49 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java @@ -38,7 +38,6 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; /** @@ -114,8 +113,8 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP authenticatedCredentials = true; } - authenticatedCredentials = authenticatedCredentials || - authenticatePkceIfAvailable(clientAuthentication, registeredClient); + boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient); + authenticatedCredentials = authenticatedCredentials || pkceAuthenticated; if (!authenticatedCredentials) { throwInvalidClient(); } @@ -133,7 +132,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP RegisteredClient registeredClient) { Map parameters = clientAuthentication.getAdditionalParameters(); - if (CollectionUtils.isEmpty(parameters) || !authorizationCodeGrant(parameters)) { + if (!authorizationCodeGrant(parameters)) { return false; } @@ -149,9 +148,12 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP String codeChallenge = (String) authorizationRequest.getAdditionalParameters() .get(PkceParameterNames.CODE_CHALLENGE); - if (!StringUtils.hasText(codeChallenge) && - registeredClient.getClientSettings().isRequireProofKey()) { - throwInvalidClient(); + if (!StringUtils.hasText(codeChallenge)) { + if (registeredClient.getClientSettings().isRequireProofKey()) { + throwInvalidClient(); + } else { + return false; + } } String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters() @@ -174,7 +176,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP if (!StringUtils.hasText(codeVerifier)) { return false; } else if (!StringUtils.hasText(codeChallengeMethod) || "plain".equals(codeChallengeMethod)) { - return codeVerifier.equals(codeChallenge); + return codeVerifier.equals(codeChallenge); } else if ("S256".equals(codeChallengeMethod)) { try { MessageDigest md = MessageDigest.getInstance("SHA-256"); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretBasicAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretBasicAuthenticationConverter.java index 9dd09b43..4a00bace 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretBasicAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretBasicAuthenticationConverter.java @@ -97,7 +97,7 @@ public final class ClientSecretBasicAuthenticationConverter implements Authentic private static Map extractAdditionalParameters(HttpServletRequest request) { Map additionalParameters = Collections.emptyMap(); - if (OAuth2EndpointUtils.matchesPkceTokenRequest(request)) { + if (OAuth2EndpointUtils.matchesAuthorizationCodeGrantRequest(request)) { // Confidential clients can also leverage PKCE additionalParameters = new HashMap<>(OAuth2EndpointUtils.getParameters(request).toSingleValueMap()); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java index b9ed9b6e..a8418ebd 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java @@ -77,7 +77,7 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica private static Map extractAdditionalParameters(HttpServletRequest request) { Map additionalParameters = Collections.emptyMap(); - if (OAuth2EndpointUtils.matchesPkceTokenRequest(request)) { + if (OAuth2EndpointUtils.matchesAuthorizationCodeGrantRequest(request)) { // Confidential clients can also leverage PKCE additionalParameters = new HashMap<>(OAuth2EndpointUtils.getParameters(request).toSingleValueMap()); additionalParameters.remove(OAuth2ParameterNames.CLIENT_ID); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java index b53886d7..fd2ba5d4 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java @@ -52,10 +52,14 @@ final class OAuth2EndpointUtils { return parameters; } - static boolean matchesPkceTokenRequest(HttpServletRequest request) { + static boolean matchesAuthorizationCodeGrantRequest(HttpServletRequest request) { return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals( request.getParameter(OAuth2ParameterNames.GRANT_TYPE)) && - request.getParameter(OAuth2ParameterNames.CODE) != null && + request.getParameter(OAuth2ParameterNames.CODE) != null; + } + + static boolean matchesPkceTokenRequest(HttpServletRequest request) { + return matchesAuthorizationCodeGrantRequest(request) && request.getParameter(PkceParameterNames.CODE_VERIFIER) != null; } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index 7c4b3e8c..364b43c8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -370,6 +370,35 @@ public class OAuth2AuthorizationCodeGrantTests { assertThat(authorizationCodeToken.getMetadata().get(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME)).isEqualTo(true); } + @Test + public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenUnauthorized() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + this.registeredClientRepository.save(registeredClient); + + MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) + .params(getAuthorizationRequestParameters(registeredClient)) + .param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) + .param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") + .with(user("user"))) + .andExpect(status().is3xxRedirection()) + .andReturn(); + String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + + String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); + OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); + assertThat(authorizationCodeAuthorization).isNotNull(); + assertThat(authorizationCodeAuthorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + + this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI) + .params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization)) + .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) + .header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient))) + .andExpect(status().isUnauthorized()); + } + @Test public void requestWhenCustomJwtEncoderThenUsed() throws Exception { this.spring.register(AuthorizationServerConfigurationWithJwtEncoder.class).autowire(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java index 3b49f786..c04191c4 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java @@ -36,7 +36,6 @@ import org.springframework.security.oauth2.server.authorization.TestOAuth2Author import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; -import org.springframework.security.oauth2.server.authorization.config.ClientSettings; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -182,6 +181,26 @@ public class OAuth2ClientAuthenticationProviderTests { assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); } + @Test + public void authenticateWhenAuthorizationCodeGrantAndValidCredentialsThenAuthenticated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(TestOAuth2Authorizations.authorization().build()); + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), createAuthorizationCodeTokenParameters()); + OAuth2ClientAuthenticationToken authenticationResult = + (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + verify(this.passwordEncoder).matches(any(), any()); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getCredentials().toString()).isEqualTo(registeredClient.getClientSecret()); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); + } + @Test public void authenticateWhenPkceAndInvalidCodeThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); @@ -208,20 +227,18 @@ public class OAuth2ClientAuthenticationProviderTests { } @Test - public void authenticateWhenPkceAndRequireProofKeyAndMissingCodeChallengeThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient() - .clientSettings(ClientSettings.builder().requireProofKey(true).build()) - .build(); + public void authenticateWhenPkceAndPublicClientAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient) + .authorization(registeredClient, createPkceAuthorizationParametersPlain()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); - Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); + Map parameters = createAuthorizationCodeTokenParameters(); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); @@ -234,8 +251,8 @@ public class OAuth2ClientAuthenticationProviderTests { } @Test - public void authenticateWhenPkceAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + public void authenticateWhenPkceAndConfidentialClientAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); @@ -245,11 +262,10 @@ public class OAuth2ClientAuthenticationProviderTests { when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); - Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); - parameters.remove(PkceParameterNames.CODE_VERIFIER); + Map parameters = createAuthorizationCodeTokenParameters(); OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), parameters); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -425,10 +441,15 @@ public class OAuth2ClientAuthenticationProviderTests { .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } - private static Map createPkceTokenParameters(String codeVerifier) { + private static Map createAuthorizationCodeTokenParameters() { Map parameters = new HashMap<>(); parameters.put(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); parameters.put(OAuth2ParameterNames.CODE, AUTHORIZATION_CODE); + return parameters; + } + + private static Map createPkceTokenParameters(String codeVerifier) { + Map parameters = createAuthorizationCodeTokenParameters(); parameters.put(PkceParameterNames.CODE_VERIFIER, codeVerifier); return parameters; }