Browse Source

Polish OAuth2ClientAuthenticationProviderTests

pull/483/head
Joe Grandja 4 years ago
parent
commit
b455268fa1
  1. 33
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java
  2. 4
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java
  3. 54
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java

33
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java

@ -28,6 +28,7 @@ import org.springframework.security.crypto.factory.PasswordEncoderFactories;
import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2TokenType; import org.springframework.security.oauth2.core.OAuth2TokenType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
@ -54,6 +55,7 @@ import org.springframework.util.StringUtils;
* @see PasswordEncoder * @see PasswordEncoder
*/ */
public final class OAuth2ClientAuthenticationProvider implements AuthenticationProvider { public final class OAuth2ClientAuthenticationProvider implements AuthenticationProvider {
private static final String CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-01#section-3.2.1";
private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
private final RegisteredClientRepository registeredClientRepository; private final RegisteredClientRepository registeredClientRepository;
private final OAuth2AuthorizationService authorizationService; private final OAuth2AuthorizationService authorizationService;
@ -95,28 +97,28 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
String clientId = clientAuthentication.getPrincipal().toString(); String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) { if (registeredClient == null) {
throwInvalidClient(); throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
} }
if (!registeredClient.getClientAuthenticationMethods().contains( if (!registeredClient.getClientAuthenticationMethods().contains(
clientAuthentication.getClientAuthenticationMethod())) { clientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient(); throwInvalidClient("authentication_method");
} }
boolean authenticatedCredentials = false; boolean credentialsAuthenticated = false;
if (clientAuthentication.getCredentials() != null) { if (clientAuthentication.getCredentials() != null) {
String clientSecret = clientAuthentication.getCredentials().toString(); String clientSecret = clientAuthentication.getCredentials().toString();
if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) { if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) {
throwInvalidClient(); throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET);
} }
authenticatedCredentials = true; credentialsAuthenticated = true;
} }
boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient); boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient);
authenticatedCredentials = authenticatedCredentials || pkceAuthenticated; credentialsAuthenticated = credentialsAuthenticated || pkceAuthenticated;
if (!authenticatedCredentials) { if (!credentialsAuthenticated) {
throwInvalidClient(); throwInvalidClient("credentials");
} }
return new OAuth2ClientAuthenticationToken(registeredClient, return new OAuth2ClientAuthenticationToken(registeredClient,
@ -140,7 +142,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
(String) parameters.get(OAuth2ParameterNames.CODE), (String) parameters.get(OAuth2ParameterNames.CODE),
AUTHORIZATION_CODE_TOKEN_TYPE); AUTHORIZATION_CODE_TOKEN_TYPE);
if (authorization == null) { if (authorization == null) {
throwInvalidClient(); throwInvalidClient(OAuth2ParameterNames.CODE);
} }
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
@ -150,7 +152,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
.get(PkceParameterNames.CODE_CHALLENGE); .get(PkceParameterNames.CODE_CHALLENGE);
if (!StringUtils.hasText(codeChallenge)) { if (!StringUtils.hasText(codeChallenge)) {
if (registeredClient.getClientSettings().isRequireProofKey()) { if (registeredClient.getClientSettings().isRequireProofKey()) {
throwInvalidClient(); throwInvalidClient(PkceParameterNames.CODE_CHALLENGE);
} else { } else {
return false; return false;
} }
@ -160,7 +162,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
.get(PkceParameterNames.CODE_CHALLENGE_METHOD); .get(PkceParameterNames.CODE_CHALLENGE_METHOD);
String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER); String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER);
if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) { if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
throwInvalidClient(); throwInvalidClient(PkceParameterNames.CODE_VERIFIER);
} }
return true; return true;
@ -191,7 +193,12 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.SERVER_ERROR); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.SERVER_ERROR);
} }
private static void throwInvalidClient() { private static void throwInvalidClient(String parameterName) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT); OAuth2Error error = new OAuth2Error(
OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName,
CLIENT_AUTHENTICATION_ERROR_URI);
throw new OAuth2AuthenticationException(error);
} }
} }

4
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java

@ -178,7 +178,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
} else { } else {
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
} }
this.errorHttpResponseConverter.write(error, null, httpResponse); // We don't want to reveal too much information to the caller so just return the error code
OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode());
this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
} }
} }

54
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java

@ -128,8 +128,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode") .satisfies(error -> {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_ID);
});
} }
@Test @Test
@ -143,8 +145,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode") .satisfies(error -> {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_SECRET);
});
verify(this.passwordEncoder).matches(any(), any()); verify(this.passwordEncoder).matches(any(), any());
} }
@ -159,8 +163,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode") .satisfies(error -> {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains("credentials");
});
} }
@Test @Test
@ -222,8 +228,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode") .satisfies(error -> {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(OAuth2ParameterNames.CODE);
});
} }
@Test @Test
@ -246,8 +254,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode") .satisfies(error -> {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
});
} }
@Test @Test
@ -270,8 +280,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode") .satisfies(error -> {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
});
} }
@Test @Test
@ -294,8 +306,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode") .satisfies(error -> {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
});
} }
@Test @Test
@ -318,8 +332,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode") .satisfies(error -> {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
});
} }
@Test @Test
@ -437,8 +453,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode") .satisfies(error -> {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains("authentication_method");
});
} }
private static Map<String, Object> createAuthorizationCodeTokenParameters() { private static Map<String, Object> createAuthorizationCodeTokenParameters() {

Loading…
Cancel
Save