@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2023 the original author or authors .
* Copyright 2002 - 2025 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -16,6 +16,12 @@
@@ -16,6 +16,12 @@
package org.springframework.security.oauth2.jwt ;
import com.nimbusds.jose.KeySourceException ;
import com.nimbusds.jose.jwk.JWK ;
import com.nimbusds.jose.jwk.JWKMatcher ;
import com.nimbusds.jose.jwk.JWKSelector ;
import com.nimbusds.jose.jwk.source.JWKSetParseException ;
import com.nimbusds.jose.jwk.source.JWKSetRetrievalException ;
import java.io.IOException ;
import java.net.MalformedURLException ;
import java.net.URL ;
@ -26,8 +32,10 @@ import java.util.Collection;
@@ -26,8 +32,10 @@ import java.util.Collection;
import java.util.Collections ;
import java.util.HashSet ;
import java.util.LinkedHashMap ;
import java.util.List ;
import java.util.Map ;
import java.util.Set ;
import java.util.concurrent.locks.ReentrantLock ;
import java.util.function.Consumer ;
import java.util.function.Function ;
@ -35,17 +43,12 @@ import javax.crypto.SecretKey;
@@ -35,17 +43,12 @@ import javax.crypto.SecretKey;
import com.nimbusds.jose.JOSEException ;
import com.nimbusds.jose.JWSAlgorithm ;
import com.nimbusds.jose.RemoteKeySourceException ;
import com.nimbusds.jose.jwk.JWKSet ;
import com.nimbusds.jose.jwk.source.JWKSetCache ;
import com.nimbusds.jose.jwk.source.JWKSource ;
import com.nimbusds.jose.jwk.source.RemoteJWKSet ;
import com.nimbusds.jose.proc.JWSKeySelector ;
import com.nimbusds.jose.proc.JWSVerificationKeySelector ;
import com.nimbusds.jose.proc.SecurityContext ;
import com.nimbusds.jose.proc.SingleKeyJWSKeySelector ;
import com.nimbusds.jose.util.Resource ;
import com.nimbusds.jose.util.ResourceRetriever ;
import com.nimbusds.jwt.JWT ;
import com.nimbusds.jwt.JWTClaimsSet ;
import com.nimbusds.jwt.JWTParser ;
@ -57,6 +60,7 @@ import org.apache.commons.logging.Log;
@@ -57,6 +60,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory ;
import org.springframework.cache.Cache ;
import org.springframework.cache.support.NoOpCache ;
import org.springframework.core.convert.converter.Converter ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpMethod ;
@ -80,6 +84,7 @@ import org.springframework.web.client.RestTemplate;
@@ -80,6 +84,7 @@ import org.springframework.web.client.RestTemplate;
* @author Josh Cummings
* @author Joe Grandja
* @author Mykyta Bezverkhyi
* @author Daeho Kwon
* @since 5 . 2
* /
public final class NimbusJwtDecoder implements JwtDecoder {
@ -165,7 +170,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
@@ -165,7 +170,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
. build ( ) ;
// @formatter:on
}
catch ( Remote KeySourceException ex ) {
catch ( KeySourceException ex ) {
this . logger . trace ( "Failed to retrieve JWK set" , ex ) ;
if ( ex . getCause ( ) instanceof ParseException ) {
throw new JwtException ( String . format ( DECODING_ERROR_MESSAGE_TEMPLATE , "Malformed Jwk set" ) , ex ) ;
@ -273,7 +278,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
@@ -273,7 +278,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
private RestOperations restOperations = new RestTemplate ( ) ;
private Cache cache ;
private Cache cache = new NoOpCache ( "default" ) ;
private Consumer < ConfigurableJWTProcessor < SecurityContext > > jwtProcessorCustomizer ;
@ -376,18 +381,13 @@ public final class NimbusJwtDecoder implements JwtDecoder {
@@ -376,18 +381,13 @@ public final class NimbusJwtDecoder implements JwtDecoder {
return new JWSVerificationKeySelector < > ( jwsAlgorithms , jwkSource ) ;
}
JWKSource < SecurityContext > jwkSource ( ResourceRetriever jwkSetRetriever , String jwkSetUri ) {
if ( this . cache = = null ) {
return new RemoteJWKSet < > ( toURL ( jwkSetUri ) , jwkSetRetriever ) ;
}
JWKSetCache jwkSetCache = new SpringJWKSetCache ( jwkSetUri , this . cache ) ;
return new RemoteJWKSet < > ( toURL ( jwkSetUri ) , jwkSetRetriever , jwkSetCache ) ;
JWKSource < SecurityContext > jwkSource ( ) {
String jwkSetUri = this . jwkSetUri . apply ( this . restOperations ) ;
return new SpringJWKSource < > ( this . restOperations , this . cache , toURL ( jwkSetUri ) , jwkSetUri ) ;
}
JWTProcessor < SecurityContext > processor ( ) {
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever ( this . restOperations ) ;
String jwkSetUri = this . jwkSetUri . apply ( this . restOperations ) ;
JWKSource < SecurityContext > jwkSource = jwkSource ( jwkSetRetriever , jwkSetUri ) ;
JWKSource < SecurityContext > jwkSource = jwkSource ( ) ;
ConfigurableJWTProcessor < SecurityContext > jwtProcessor = new DefaultJWTProcessor < > ( ) ;
jwtProcessor . setJWSKeySelector ( jwsKeySelector ( jwkSource ) ) ;
// Spring Security validates the claim set independent from Nimbus
@ -414,84 +414,130 @@ public final class NimbusJwtDecoder implements JwtDecoder {
@@ -414,84 +414,130 @@ public final class NimbusJwtDecoder implements JwtDecoder {
}
}
private static final class SpringJWKSetCache implements JWKSetCache {
private static final class SpringJWKSource < C extends SecurityContext > implements JWKSource < C > {
private final String jwkSetUri ;
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ( "application" , "jwk-set+json" ) ;
private final ReentrantLock reentrantLock = new ReentrantLock ( ) ;
private final RestOperations restOperations ;
private final Cache cache ;
private JWKSet jwkSet ;
private final URL url ;
SpringJWKSetCache ( String jwkSetUri , Cache cache ) {
this . jwkSetUri = jwkSetUri ;
private final String jwkSetUri ;
private SpringJWKSource ( RestOperations restOperations , Cache cache , URL url , String jwkSetUri ) {
Assert . notNull ( restOperations , "restOperations cannot be null" ) ;
this . restOperations = restOperations ;
this . cache = cache ;
this . updateJwkSetFromCache ( ) ;
this . url = url ;
this . jwkSetUri = jwkSetUri ;
}
private void updateJwkSetFromCache ( ) {
@Override
public List < JWK > get ( JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
String cachedJwkSet = this . cache . get ( this . jwkSetUri , String . class ) ;
JWKSet jwkSet = null ;
if ( cachedJwkSet ! = null ) {
try {
this . jwkSet = JWKSet . parse ( cachedJwkSet ) ;
}
catch ( ParseException ignored ) {
// Ignore invalid cache value
jwkSet = parse ( cachedJwkSet ) ;
}
if ( jwkSet = = null ) {
if ( reentrantLock . tryLock ( ) ) {
try {
String cachedJwkSetAfterLock = this . cache . get ( this . jwkSetUri , String . class ) ;
if ( cachedJwkSetAfterLock ! = null ) {
jwkSet = parse ( cachedJwkSetAfterLock ) ;
}
if ( jwkSet = = null ) {
try {
jwkSet = fetchJWKSet ( ) ;
} catch ( IOException e ) {
throw new JWKSetRetrievalException ( "Couldn't retrieve JWK set from URL: " + e . getMessage ( ) , e ) ;
}
}
} finally {
reentrantLock . unlock ( ) ;
}
}
}
}
// Note: Only called from inside a synchronized block in RemoteJWKSet.
@Override
public void put ( JWKSet jwkSet ) {
this . jwkSet = jwkSet ;
this . cache . put ( this . jwkSetUri , jwkSet . toString ( false ) ) ;
}
@Override
public JWKSet get ( ) {
return ( ! requiresRefresh ( ) ) ? this . jwkSet : null ;
}
@Override
public boolean requiresRefresh ( ) {
return this . cache . get ( this . jwkSetUri ) = = null ;
}
}
private static class RestOperationsResourceRetriever implements ResourceRetriever {
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ( "application" , "jwk-set+json" ) ;
private final RestOperations restOperations ;
List < JWK > matches = jwkSelector . select ( jwkSet ) ;
if ( ! matches . isEmpty ( ) ) {
return matches ;
}
String soughtKeyID = getFirstSpecifiedKeyID ( jwkSelector . getMatcher ( ) ) ;
if ( soughtKeyID = = null ) {
return Collections . emptyList ( ) ;
}
if ( jwkSet . getKeyByKeyId ( soughtKeyID ) ! = null ) {
return Collections . emptyList ( ) ;
}
RestOperationsResourceRetriever ( RestOperations restOperations ) {
Assert . notNull ( restOperations , "restOperations cannot be null" ) ;
this . restOperations = restOperations ;
if ( reentrantLock . tryLock ( ) ) {
try {
String jwkSetUri = this . cache . get ( this . jwkSetUri , String . class ) ;
JWKSet cacheJwkSet = parse ( jwkSetUri ) ;
if ( jwkSetUri ! = null & & cacheJwkSet . toString ( ) . equals ( jwkSet . toString ( ) ) ) {
try {
jwkSet = fetchJWKSet ( ) ;
} catch ( IOException e ) {
throw new JWKSetRetrievalException ( "Couldn't retrieve JWK set from URL: " + e . getMessage ( ) , e ) ;
}
} else if ( jwkSetUri ! = null ) {
jwkSet = parse ( jwkSetUri ) ;
}
} finally {
reentrantLock . unlock ( ) ;
}
}
if ( jwkSet = = null ) {
return Collections . emptyList ( ) ;
}
return jwkSelector . select ( jwkSet ) ;
}
@Override
public Resource retrieveResource ( URL url ) throws IOException {
private JWKSet fetchJWKSet ( ) throws IOException , KeySourceException {
HttpHeaders headers = new HttpHeaders ( ) ;
headers . setAccept ( Arrays . asList ( MediaType . APPLICATION_JSON , APPLICATION_JWK_SET_JSON ) ) ;
ResponseEntity < String > response = getResponse ( url , headers ) ;
ResponseEntity < String > response = getResponse ( headers ) ;
if ( response . getStatusCode ( ) . value ( ) ! = 200 ) {
throw new IOException ( response . toString ( ) ) ;
}
return new Resource ( response . getBody ( ) , "UTF-8" ) ;
try {
String jwkSet = response . getBody ( ) ;
this . cache . put ( this . jwkSetUri , jwkSet ) ;
return JWKSet . parse ( jwkSet ) ;
} catch ( ParseException e ) {
throw new JWKSetParseException ( "Unable to parse JWK set" , e ) ;
}
}
private ResponseEntity < String > getResponse ( URL url , HttpHeaders headers ) throws IOException {
private ResponseEntity < String > getResponse ( HttpHeaders headers ) throws IOException {
try {
RequestEntity < Void > request = new RequestEntity < > ( headers , HttpMethod . GET , url . toURI ( ) ) ;
RequestEntity < Void > request = new RequestEntity < > ( headers , HttpMethod . GET , this . url . toURI ( ) ) ;
return this . restOperations . exchange ( request , String . class ) ;
}
catch ( Exception ex ) {
} catch ( Exception ex ) {
throw new IOException ( ex ) ;
}
}
private JWKSet parse ( String cachedJwkSet ) {
JWKSet jwkSet = null ;
try {
jwkSet = JWKSet . parse ( cachedJwkSet ) ;
} catch ( ParseException ignored ) {
// Ignore invalid cache value
}
return jwkSet ;
}
private String getFirstSpecifiedKeyID ( JWKMatcher jwkMatcher ) {
Set < String > keyIDs = jwkMatcher . getKeyIDs ( ) ;
return ( keyIDs = = null | | keyIDs . isEmpty ( ) ) ?
null : keyIDs . stream ( ) . filter ( id - > id ! = null ) . findFirst ( ) . orElse ( null ) ;
}
}
}