Browse Source

Polish OAuth2ClientAuthenticationProviderTests

pull/483/head
Joe Grandja 4 years ago
parent
commit
b455268fa1
  1. 33
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java
  2. 4
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java
  3. 54
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java

33
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java

@ -28,6 +28,7 @@ import org.springframework.security.crypto.factory.PasswordEncoderFactories; @@ -28,6 +28,7 @@ import org.springframework.security.crypto.factory.PasswordEncoderFactories;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2TokenType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
@ -54,6 +55,7 @@ import org.springframework.util.StringUtils; @@ -54,6 +55,7 @@ import org.springframework.util.StringUtils;
* @see PasswordEncoder
*/
public final class OAuth2ClientAuthenticationProvider implements AuthenticationProvider {
private static final String CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-01#section-3.2.1";
private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
private final RegisteredClientRepository registeredClientRepository;
private final OAuth2AuthorizationService authorizationService;
@ -95,28 +97,28 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP @@ -95,28 +97,28 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throwInvalidClient();
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
}
if (!registeredClient.getClientAuthenticationMethods().contains(
clientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient();
throwInvalidClient("authentication_method");
}
boolean authenticatedCredentials = false;
boolean credentialsAuthenticated = false;
if (clientAuthentication.getCredentials() != null) {
String clientSecret = clientAuthentication.getCredentials().toString();
if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) {
throwInvalidClient();
throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET);
}
authenticatedCredentials = true;
credentialsAuthenticated = true;
}
boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient);
authenticatedCredentials = authenticatedCredentials || pkceAuthenticated;
if (!authenticatedCredentials) {
throwInvalidClient();
credentialsAuthenticated = credentialsAuthenticated || pkceAuthenticated;
if (!credentialsAuthenticated) {
throwInvalidClient("credentials");
}
return new OAuth2ClientAuthenticationToken(registeredClient,
@ -140,7 +142,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP @@ -140,7 +142,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
(String) parameters.get(OAuth2ParameterNames.CODE),
AUTHORIZATION_CODE_TOKEN_TYPE);
if (authorization == null) {
throwInvalidClient();
throwInvalidClient(OAuth2ParameterNames.CODE);
}
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
@ -150,7 +152,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP @@ -150,7 +152,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
.get(PkceParameterNames.CODE_CHALLENGE);
if (!StringUtils.hasText(codeChallenge)) {
if (registeredClient.getClientSettings().isRequireProofKey()) {
throwInvalidClient();
throwInvalidClient(PkceParameterNames.CODE_CHALLENGE);
} else {
return false;
}
@ -160,7 +162,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP @@ -160,7 +162,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER);
if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
throwInvalidClient();
throwInvalidClient(PkceParameterNames.CODE_VERIFIER);
}
return true;
@ -191,7 +193,12 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP @@ -191,7 +193,12 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.SERVER_ERROR);
}
private static void throwInvalidClient() {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
private static void throwInvalidClient(String parameterName) {
OAuth2Error error = new OAuth2Error(
OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName,
CLIENT_AUTHENTICATION_ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
}

4
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java

@ -178,7 +178,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter @@ -178,7 +178,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
} else {
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
}
this.errorHttpResponseConverter.write(error, null, httpResponse);
// We don't want to reveal too much information to the caller so just return the error code
OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode());
this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
}
}

54
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java

@ -128,8 +128,10 @@ public class OAuth2ClientAuthenticationProviderTests { @@ -128,8 +128,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_ID);
});
}
@Test
@ -143,8 +145,10 @@ public class OAuth2ClientAuthenticationProviderTests { @@ -143,8 +145,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_SECRET);
});
verify(this.passwordEncoder).matches(any(), any());
}
@ -159,8 +163,10 @@ public class OAuth2ClientAuthenticationProviderTests { @@ -159,8 +163,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains("credentials");
});
}
@Test
@ -222,8 +228,10 @@ public class OAuth2ClientAuthenticationProviderTests { @@ -222,8 +228,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(OAuth2ParameterNames.CODE);
});
}
@Test
@ -246,8 +254,10 @@ public class OAuth2ClientAuthenticationProviderTests { @@ -246,8 +254,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
});
}
@Test
@ -270,8 +280,10 @@ public class OAuth2ClientAuthenticationProviderTests { @@ -270,8 +280,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
});
}
@Test
@ -294,8 +306,10 @@ public class OAuth2ClientAuthenticationProviderTests { @@ -294,8 +306,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
});
}
@Test
@ -318,8 +332,10 @@ public class OAuth2ClientAuthenticationProviderTests { @@ -318,8 +332,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
});
}
@Test
@ -437,8 +453,10 @@ public class OAuth2ClientAuthenticationProviderTests { @@ -437,8 +453,10 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
assertThat(error.getDescription()).contains("authentication_method");
});
}
private static Map<String, Object> createAuthorizationCodeTokenParameters() {

Loading…
Cancel
Save