Browse Source

Add generic request validator for refresh token

Signed-off-by: Andrey Litvitski <andrey1010102008@gmail.com>
pull/18098/head
Andrey Litvitski 2 months ago
parent
commit
dcc85d8df4
  1. 111
      oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationContext.java
  2. 71
      oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java
  3. 114
      oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationValidator.java

111
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationContext.java

@ -0,0 +1,111 @@ @@ -0,0 +1,111 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.util.Assert;
/**
* An {@link OAuth2AuthenticationContext} that holds an
* {@link OAuth2RefreshTokenAuthenticationToken} and additional information and is used
* when validating the OAuth 2.0 Refresh Token Grant Request.
* <p>
* This context provides access to the current {@link OAuth2Authorization},
* {@link OAuth2ClientAuthenticationToken}, and optionally a DPoP {@link Jwt} proof.
* </p>
*
* @author Andrey Litvitski
* @since 7.0.0
* @see OAuth2AuthenticationContext
* @see OAuth2RefreshTokenAuthenticationProvider#setAuthenticationValidator(Consumer)
*/
public final class OAuth2RefreshTokenAuthenticationContext implements OAuth2AuthenticationContext {
private final Map<Object, Object> context;
private OAuth2RefreshTokenAuthenticationContext(Map<Object, Object> context) {
this.context = Collections.unmodifiableMap(new HashMap<>(context));
}
@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}
@Override
public boolean hasKey(Object key) {
Assert.notNull(key, "key cannot be null");
return this.context.containsKey(key);
}
public OAuth2Authorization getAuthorization() {
return get(OAuth2Authorization.class);
}
public OAuth2ClientAuthenticationToken getClientPrincipal() {
return get(OAuth2ClientAuthenticationToken.class);
}
@Nullable public Jwt getDPoPProof() {
return get(Jwt.class);
}
public static Builder with(OAuth2RefreshTokenAuthenticationToken authentication) {
return new Builder(authentication);
}
public static final class Builder extends AbstractBuilder<OAuth2RefreshTokenAuthenticationContext, Builder> {
private Builder(OAuth2RefreshTokenAuthenticationToken authentication) {
super(authentication);
}
public Builder authorization(OAuth2Authorization authorization) {
return put(OAuth2Authorization.class, authorization);
}
public Builder clientPrincipal(OAuth2ClientAuthenticationToken clientPrincipal) {
return put(OAuth2ClientAuthenticationToken.class, clientPrincipal);
}
public Builder dPoPProof(@Nullable Jwt dPoPProof) {
if (dPoPProof != null) {
put(Jwt.class, dPoPProof);
}
return this;
}
@Override
public OAuth2RefreshTokenAuthenticationContext build() {
Assert.notNull(get(OAuth2Authorization.class), "authorization cannot be null");
Assert.notNull(get(OAuth2ClientAuthenticationToken.class), "clientPrincipal cannot be null");
return new OAuth2RefreshTokenAuthenticationContext(getContext());
}
}
}

71
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java

@ -21,8 +21,8 @@ import java.util.Collections; @@ -21,8 +21,8 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import com.nimbusds.jose.jwk.JWK;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -31,8 +31,6 @@ import org.springframework.security.authentication.AuthenticationProvider; @@ -31,8 +31,6 @@ import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
@ -52,7 +50,6 @@ import org.springframework.security.oauth2.server.authorization.token.DefaultOAu @@ -52,7 +50,6 @@ import org.springframework.security.oauth2.server.authorization.token.DefaultOAu
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/**
* An {@link AuthenticationProvider} implementation for the OAuth 2.0 Refresh Token Grant.
@ -60,6 +57,7 @@ import org.springframework.util.CollectionUtils; @@ -60,6 +57,7 @@ import org.springframework.util.CollectionUtils;
* @author Alexey Nesterov
* @author Joe Grandja
* @author Anoop Garlapati
* @author Andrey Litvitski
* @since 7.0
* @see OAuth2RefreshTokenAuthenticationToken
* @see OAuth2AccessTokenAuthenticationToken
@ -84,6 +82,8 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic @@ -84,6 +82,8 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
private Consumer<OAuth2RefreshTokenAuthenticationContext> authenticationValidator = new OAuth2RefreshTokenAuthenticationValidator();
/**
* Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided
* parameters.
@ -164,13 +164,14 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic @@ -164,13 +164,14 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
// Verify the DPoP Proof (if available)
Jwt dPoPProof = DPoPProofVerifier.verifyIfAvailable(refreshTokenAuthentication);
if (dPoPProof != null
&& clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) {
// For public clients, verify the DPoP Proof public key is same as (current)
// access token public key binding
Map<String, Object> accessTokenClaims = authorization.getAccessToken().getClaims();
verifyDPoPProofPublicKey(dPoPProof, () -> accessTokenClaims);
}
OAuth2RefreshTokenAuthenticationContext context = OAuth2RefreshTokenAuthenticationContext
.with(refreshTokenAuthentication)
.authorization(authorization)
.clientPrincipal(clientPrincipal)
.dPoPProof(dPoPProof)
.build();
this.authenticationValidator.accept(context);
if (this.logger.isTraceEnabled()) {
this.logger.trace("Validated token request parameters");
@ -292,45 +293,15 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic @@ -292,45 +293,15 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
}
private static void verifyDPoPProofPublicKey(Jwt dPoPProof, ClaimAccessor accessTokenClaims) {
JWK jwk = null;
@SuppressWarnings("unchecked")
Map<String, Object> jwkJson = (Map<String, Object>) dPoPProof.getHeaders().get("jwk");
try {
jwk = JWK.parse(jwkJson);
}
catch (Exception ignored) {
}
if (jwk == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF,
"jwk header is missing or invalid.", null);
throw new OAuth2AuthenticationException(error);
}
String jwkThumbprint;
try {
jwkThumbprint = jwk.computeThumbprint().toString();
}
catch (Exception ex) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF,
"Failed to compute SHA-256 Thumbprint for jwk.", null);
throw new OAuth2AuthenticationException(error);
}
String jwkThumbprintClaim = null;
Map<String, Object> confirmationMethodClaim = accessTokenClaims.getClaimAsMap("cnf");
if (!CollectionUtils.isEmpty(confirmationMethodClaim) && confirmationMethodClaim.containsKey("jkt")) {
jwkThumbprintClaim = (String) confirmationMethodClaim.get("jkt");
}
if (jwkThumbprintClaim == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, "jkt claim is missing.", null);
throw new OAuth2AuthenticationException(error);
}
if (!jwkThumbprint.equals(jwkThumbprintClaim)) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, "jwk header is invalid.", null);
throw new OAuth2AuthenticationException(error);
}
/**
* Sets the {@code Consumer} responsible for validating the OAuth 2.0 Refresh Token
* Grant Request using the provided {@link OAuth2RefreshTokenAuthenticationContext}.
* <p>
* The default validator performs DPoP proof verification if present.
*/
public void setAuthenticationValidator(Consumer<OAuth2RefreshTokenAuthenticationContext> authenticationValidator) {
Assert.notNull(authenticationValidator, "authenticationValidator cannot be null");
this.authenticationValidator = authenticationValidator;
}
}

114
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationValidator.java

@ -0,0 +1,114 @@ @@ -0,0 +1,114 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Map;
import java.util.function.Consumer;
import com.nimbusds.jose.jwk.JWK;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.util.CollectionUtils;
/**
* A {@code Consumer} that validates an {@link OAuth2RefreshTokenAuthenticationContext}
* and acts as the default
* {@link OAuth2RefreshTokenAuthenticationProvider#setAuthenticationValidator(Consumer)
* authentication validator} for the Refresh Token grant.
* <p>
* The default implementation validates a DPoP proof if present and throws
* {@link OAuth2AuthenticationException} on failure.
* </p>
*
* @author Andrey Litvitski
* @since 7.0.0
* @see OAuth2RefreshTokenAuthenticationContext
* @see OAuth2RefreshTokenAuthenticationProvider#setAuthenticationValidator(Consumer)
*/
public final class OAuth2RefreshTokenAuthenticationValidator
implements Consumer<OAuth2RefreshTokenAuthenticationContext> {
public static final Consumer<OAuth2RefreshTokenAuthenticationContext> DEFAULT_VALIDATOR = OAuth2RefreshTokenAuthenticationValidator::validateDefault;
private final Consumer<OAuth2RefreshTokenAuthenticationContext> authenticationValidator = DEFAULT_VALIDATOR;
@Override
public void accept(OAuth2RefreshTokenAuthenticationContext context) {
this.authenticationValidator.accept(context);
}
private static void validateDefault(OAuth2RefreshTokenAuthenticationContext context) {
Jwt dPoPProof;
if (context.getDPoPProof() == null) {
dPoPProof = DPoPProofVerifier.verifyIfAvailable(context.getAuthentication());
}
else {
dPoPProof = context.getDPoPProof();
}
if (dPoPProof == null || !context.getClientPrincipal()
.getClientAuthenticationMethod()
.equals(ClientAuthenticationMethod.NONE)) {
return;
}
JWK jwk = null;
@SuppressWarnings("unchecked")
Map<String, Object> jwkJson = (Map<String, Object>) dPoPProof.getHeaders().get("jwk");
try {
jwk = JWK.parse(jwkJson);
}
catch (Exception ignored) {
}
if (jwk == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF,
"jwk header is missing or invalid.", null);
throw new OAuth2AuthenticationException(error);
}
String jwkThumbprint;
try {
jwkThumbprint = jwk.computeThumbprint().toString();
}
catch (Exception ex) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF,
"Failed to compute SHA-256 Thumbprint for jwk.", null);
throw new OAuth2AuthenticationException(error);
}
String jwkThumbprintClaim = null;
Map<String, Object> accessTokenClaimsMap = context.getAuthorization().getAccessToken().getClaims();
ClaimAccessor accessTokenClaims = () -> accessTokenClaimsMap;
Map<String, Object> confirmationMethodClaim = accessTokenClaims.getClaimAsMap("cnf");
if (!CollectionUtils.isEmpty(confirmationMethodClaim) && confirmationMethodClaim.containsKey("jkt")) {
jwkThumbprintClaim = (String) confirmationMethodClaim.get("jkt");
}
if (jwkThumbprintClaim == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, "jkt claim is missing.", null);
throw new OAuth2AuthenticationException(error);
}
if (!jwkThumbprint.equals(jwkThumbprintClaim)) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, "jwk header is invalid.", null);
throw new OAuth2AuthenticationException(error);
}
}
}
Loading…
Cancel
Save