Browse Source

Fix DPoP jkt claim validation during refresh_token grant for public clients

Closes gh-2008
pull/2011/head
Joe Grandja 7 months ago
parent
commit
86b5607a03
  1. 21
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java
  2. 3
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2RefreshTokenGrantTests.java

21
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java

@ -15,16 +15,12 @@
*/ */
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.security.MessageDigest;
import java.security.Principal; import java.security.Principal;
import java.security.PublicKey;
import java.util.Base64;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import com.nimbusds.jose.jwk.AsymmetricJWK;
import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWK;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -292,18 +288,15 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
} }
private static void verifyDPoPProofPublicKey(Jwt dPoPProof, ClaimAccessor accessTokenClaims) { private static void verifyDPoPProofPublicKey(Jwt dPoPProof, ClaimAccessor accessTokenClaims) {
PublicKey publicKey = null; JWK jwk = null;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Map<String, Object> jwkJson = (Map<String, Object>) dPoPProof.getHeaders().get("jwk"); Map<String, Object> jwkJson = (Map<String, Object>) dPoPProof.getHeaders().get("jwk");
try { try {
JWK jwk = JWK.parse(jwkJson); jwk = JWK.parse(jwkJson);
if (jwk instanceof AsymmetricJWK) {
publicKey = ((AsymmetricJWK) jwk).toPublicKey();
}
} }
catch (Exception ignored) { catch (Exception ignored) {
} }
if (publicKey == null) { if (jwk == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF,
"jwk header is missing or invalid.", null); "jwk header is missing or invalid.", null);
throw new OAuth2AuthenticationException(error); throw new OAuth2AuthenticationException(error);
@ -311,7 +304,7 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
String jwkThumbprint; String jwkThumbprint;
try { try {
jwkThumbprint = computeSHA256(publicKey); jwkThumbprint = jwk.computeThumbprint().toString();
} }
catch (Exception ex) { catch (Exception ex) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF,
@ -335,10 +328,4 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
} }
} }
private static String computeSHA256(PublicKey publicKey) throws Exception {
MessageDigest md = MessageDigest.getInstance("SHA-256");
byte[] digest = md.digest(publicKey.getEncoded());
return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
}
} }

3
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2RefreshTokenGrantTests.java

@ -295,9 +295,8 @@ public class OAuth2RefreshTokenGrantTests {
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.DPOP, OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.DPOP,
"dpop-bound-access-token", Instant.now(), Instant.now().plusSeconds(300)); "dpop-bound-access-token", Instant.now(), Instant.now().plusSeconds(300));
Map<String, Object> accessTokenClaims = new HashMap<>(); Map<String, Object> accessTokenClaims = new HashMap<>();
PublicKey publicKey = TestJwks.DEFAULT_EC_JWK.toPublicKey();
Map<String, Object> cnfClaim = new HashMap<>(); Map<String, Object> cnfClaim = new HashMap<>();
cnfClaim.put("jkt", computeSHA256(publicKey)); cnfClaim.put("jkt", TestJwks.DEFAULT_EC_JWK.toPublicJWK().computeThumbprint().toString());
accessTokenClaims.put("cnf", cnfClaim); accessTokenClaims.put("cnf", cnfClaim);
OAuth2Authorization authorization = TestOAuth2Authorizations OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, accessToken, accessTokenClaims) .authorization(registeredClient, accessToken, accessTokenClaims)

Loading…
Cancel
Save