|
|
|
@ -1,5 +1,5 @@ |
|
|
|
/* |
|
|
|
/* |
|
|
|
* Copyright 2002-2023 the original author or authors. |
|
|
|
* Copyright 2002-2025 the original author or authors. |
|
|
|
* |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@ -16,6 +16,12 @@ |
|
|
|
|
|
|
|
|
|
|
|
package org.springframework.security.oauth2.jwt; |
|
|
|
package org.springframework.security.oauth2.jwt; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import com.nimbusds.jose.KeySourceException; |
|
|
|
|
|
|
|
import com.nimbusds.jose.jwk.JWK; |
|
|
|
|
|
|
|
import com.nimbusds.jose.jwk.JWKMatcher; |
|
|
|
|
|
|
|
import com.nimbusds.jose.jwk.JWKSelector; |
|
|
|
|
|
|
|
import com.nimbusds.jose.jwk.source.JWKSetParseException; |
|
|
|
|
|
|
|
import com.nimbusds.jose.jwk.source.JWKSetRetrievalException; |
|
|
|
import java.io.IOException; |
|
|
|
import java.io.IOException; |
|
|
|
import java.net.MalformedURLException; |
|
|
|
import java.net.MalformedURLException; |
|
|
|
import java.net.URL; |
|
|
|
import java.net.URL; |
|
|
|
@ -26,8 +32,10 @@ import java.util.Collection; |
|
|
|
import java.util.Collections; |
|
|
|
import java.util.Collections; |
|
|
|
import java.util.HashSet; |
|
|
|
import java.util.HashSet; |
|
|
|
import java.util.LinkedHashMap; |
|
|
|
import java.util.LinkedHashMap; |
|
|
|
|
|
|
|
import java.util.List; |
|
|
|
import java.util.Map; |
|
|
|
import java.util.Map; |
|
|
|
import java.util.Set; |
|
|
|
import java.util.Set; |
|
|
|
|
|
|
|
import java.util.concurrent.locks.ReentrantLock; |
|
|
|
import java.util.function.Consumer; |
|
|
|
import java.util.function.Consumer; |
|
|
|
import java.util.function.Function; |
|
|
|
import java.util.function.Function; |
|
|
|
|
|
|
|
|
|
|
|
@ -35,17 +43,12 @@ import javax.crypto.SecretKey; |
|
|
|
|
|
|
|
|
|
|
|
import com.nimbusds.jose.JOSEException; |
|
|
|
import com.nimbusds.jose.JOSEException; |
|
|
|
import com.nimbusds.jose.JWSAlgorithm; |
|
|
|
import com.nimbusds.jose.JWSAlgorithm; |
|
|
|
import com.nimbusds.jose.RemoteKeySourceException; |
|
|
|
|
|
|
|
import com.nimbusds.jose.jwk.JWKSet; |
|
|
|
import com.nimbusds.jose.jwk.JWKSet; |
|
|
|
import com.nimbusds.jose.jwk.source.JWKSetCache; |
|
|
|
|
|
|
|
import com.nimbusds.jose.jwk.source.JWKSource; |
|
|
|
import com.nimbusds.jose.jwk.source.JWKSource; |
|
|
|
import com.nimbusds.jose.jwk.source.RemoteJWKSet; |
|
|
|
|
|
|
|
import com.nimbusds.jose.proc.JWSKeySelector; |
|
|
|
import com.nimbusds.jose.proc.JWSKeySelector; |
|
|
|
import com.nimbusds.jose.proc.JWSVerificationKeySelector; |
|
|
|
import com.nimbusds.jose.proc.JWSVerificationKeySelector; |
|
|
|
import com.nimbusds.jose.proc.SecurityContext; |
|
|
|
import com.nimbusds.jose.proc.SecurityContext; |
|
|
|
import com.nimbusds.jose.proc.SingleKeyJWSKeySelector; |
|
|
|
import com.nimbusds.jose.proc.SingleKeyJWSKeySelector; |
|
|
|
import com.nimbusds.jose.util.Resource; |
|
|
|
|
|
|
|
import com.nimbusds.jose.util.ResourceRetriever; |
|
|
|
|
|
|
|
import com.nimbusds.jwt.JWT; |
|
|
|
import com.nimbusds.jwt.JWT; |
|
|
|
import com.nimbusds.jwt.JWTClaimsSet; |
|
|
|
import com.nimbusds.jwt.JWTClaimsSet; |
|
|
|
import com.nimbusds.jwt.JWTParser; |
|
|
|
import com.nimbusds.jwt.JWTParser; |
|
|
|
@ -57,6 +60,7 @@ import org.apache.commons.logging.Log; |
|
|
|
import org.apache.commons.logging.LogFactory; |
|
|
|
import org.apache.commons.logging.LogFactory; |
|
|
|
|
|
|
|
|
|
|
|
import org.springframework.cache.Cache; |
|
|
|
import org.springframework.cache.Cache; |
|
|
|
|
|
|
|
import org.springframework.cache.support.NoOpCache; |
|
|
|
import org.springframework.core.convert.converter.Converter; |
|
|
|
import org.springframework.core.convert.converter.Converter; |
|
|
|
import org.springframework.http.HttpHeaders; |
|
|
|
import org.springframework.http.HttpHeaders; |
|
|
|
import org.springframework.http.HttpMethod; |
|
|
|
import org.springframework.http.HttpMethod; |
|
|
|
@ -80,6 +84,7 @@ import org.springframework.web.client.RestTemplate; |
|
|
|
* @author Josh Cummings |
|
|
|
* @author Josh Cummings |
|
|
|
* @author Joe Grandja |
|
|
|
* @author Joe Grandja |
|
|
|
* @author Mykyta Bezverkhyi |
|
|
|
* @author Mykyta Bezverkhyi |
|
|
|
|
|
|
|
* @author Daeho Kwon |
|
|
|
* @since 5.2 |
|
|
|
* @since 5.2 |
|
|
|
*/ |
|
|
|
*/ |
|
|
|
public final class NimbusJwtDecoder implements JwtDecoder { |
|
|
|
public final class NimbusJwtDecoder implements JwtDecoder { |
|
|
|
@ -165,7 +170,7 @@ public final class NimbusJwtDecoder implements JwtDecoder { |
|
|
|
.build(); |
|
|
|
.build(); |
|
|
|
// @formatter:on
|
|
|
|
// @formatter:on
|
|
|
|
} |
|
|
|
} |
|
|
|
catch (RemoteKeySourceException ex) { |
|
|
|
catch (KeySourceException ex) { |
|
|
|
this.logger.trace("Failed to retrieve JWK set", ex); |
|
|
|
this.logger.trace("Failed to retrieve JWK set", ex); |
|
|
|
if (ex.getCause() instanceof ParseException) { |
|
|
|
if (ex.getCause() instanceof ParseException) { |
|
|
|
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex); |
|
|
|
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex); |
|
|
|
@ -273,7 +278,7 @@ public final class NimbusJwtDecoder implements JwtDecoder { |
|
|
|
|
|
|
|
|
|
|
|
private RestOperations restOperations = new RestTemplate(); |
|
|
|
private RestOperations restOperations = new RestTemplate(); |
|
|
|
|
|
|
|
|
|
|
|
private Cache cache; |
|
|
|
private Cache cache = new NoOpCache("default"); |
|
|
|
|
|
|
|
|
|
|
|
private Consumer<ConfigurableJWTProcessor<SecurityContext>> jwtProcessorCustomizer; |
|
|
|
private Consumer<ConfigurableJWTProcessor<SecurityContext>> jwtProcessorCustomizer; |
|
|
|
|
|
|
|
|
|
|
|
@ -376,18 +381,13 @@ public final class NimbusJwtDecoder implements JwtDecoder { |
|
|
|
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); |
|
|
|
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) { |
|
|
|
JWKSource<SecurityContext> jwkSource() { |
|
|
|
if (this.cache == null) { |
|
|
|
String jwkSetUri = this.jwkSetUri.apply(this.restOperations); |
|
|
|
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever); |
|
|
|
return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri); |
|
|
|
} |
|
|
|
|
|
|
|
JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache); |
|
|
|
|
|
|
|
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
JWTProcessor<SecurityContext> processor() { |
|
|
|
JWTProcessor<SecurityContext> processor() { |
|
|
|
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations); |
|
|
|
JWKSource<SecurityContext> jwkSource = jwkSource(); |
|
|
|
String jwkSetUri = this.jwkSetUri.apply(this.restOperations); |
|
|
|
|
|
|
|
JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever, jwkSetUri); |
|
|
|
|
|
|
|
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>(); |
|
|
|
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>(); |
|
|
|
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource)); |
|
|
|
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource)); |
|
|
|
// Spring Security validates the claim set independent from Nimbus
|
|
|
|
// Spring Security validates the claim set independent from Nimbus
|
|
|
|
@ -414,84 +414,130 @@ public final class NimbusJwtDecoder implements JwtDecoder { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private static final class SpringJWKSetCache implements JWKSetCache { |
|
|
|
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSource<C> { |
|
|
|
|
|
|
|
|
|
|
|
private final String jwkSetUri; |
|
|
|
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private final ReentrantLock reentrantLock = new ReentrantLock(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private final RestOperations restOperations; |
|
|
|
|
|
|
|
|
|
|
|
private final Cache cache; |
|
|
|
private final Cache cache; |
|
|
|
|
|
|
|
|
|
|
|
private JWKSet jwkSet; |
|
|
|
private final URL url; |
|
|
|
|
|
|
|
|
|
|
|
SpringJWKSetCache(String jwkSetUri, Cache cache) { |
|
|
|
private final String jwkSetUri; |
|
|
|
this.jwkSetUri = jwkSetUri; |
|
|
|
|
|
|
|
|
|
|
|
private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) { |
|
|
|
|
|
|
|
Assert.notNull(restOperations, "restOperations cannot be null"); |
|
|
|
|
|
|
|
this.restOperations = restOperations; |
|
|
|
this.cache = cache; |
|
|
|
this.cache = cache; |
|
|
|
this.updateJwkSetFromCache(); |
|
|
|
this.url = url; |
|
|
|
|
|
|
|
this.jwkSetUri = jwkSetUri; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private void updateJwkSetFromCache() { |
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException { |
|
|
|
String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class); |
|
|
|
String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class); |
|
|
|
|
|
|
|
JWKSet jwkSet = null; |
|
|
|
if (cachedJwkSet != null) { |
|
|
|
if (cachedJwkSet != null) { |
|
|
|
try { |
|
|
|
jwkSet = parse(cachedJwkSet); |
|
|
|
this.jwkSet = JWKSet.parse(cachedJwkSet); |
|
|
|
} |
|
|
|
} |
|
|
|
if (jwkSet == null) { |
|
|
|
catch (ParseException ignored) { |
|
|
|
if(reentrantLock.tryLock()) { |
|
|
|
// Ignore invalid cache value
|
|
|
|
try { |
|
|
|
|
|
|
|
String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class); |
|
|
|
|
|
|
|
if (cachedJwkSetAfterLock != null) { |
|
|
|
|
|
|
|
jwkSet = parse(cachedJwkSetAfterLock); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if(jwkSet == null) { |
|
|
|
|
|
|
|
try { |
|
|
|
|
|
|
|
jwkSet = fetchJWKSet(); |
|
|
|
|
|
|
|
} catch (IOException e) { |
|
|
|
|
|
|
|
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} finally { |
|
|
|
|
|
|
|
reentrantLock.unlock(); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
List<JWK> matches = jwkSelector.select(jwkSet); |
|
|
|
|
|
|
|
if(!matches.isEmpty()) { |
|
|
|
// Note: Only called from inside a synchronized block in RemoteJWKSet.
|
|
|
|
return matches; |
|
|
|
@Override |
|
|
|
} |
|
|
|
public void put(JWKSet jwkSet) { |
|
|
|
String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher()); |
|
|
|
this.jwkSet = jwkSet; |
|
|
|
if (soughtKeyID == null) { |
|
|
|
this.cache.put(this.jwkSetUri, jwkSet.toString(false)); |
|
|
|
return Collections.emptyList(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) { |
|
|
|
@Override |
|
|
|
return Collections.emptyList(); |
|
|
|
public JWKSet get() { |
|
|
|
} |
|
|
|
return (!requiresRefresh()) ? this.jwkSet : null; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public boolean requiresRefresh() { |
|
|
|
|
|
|
|
return this.cache.get(this.jwkSetUri) == null; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static class RestOperationsResourceRetriever implements ResourceRetriever { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private final RestOperations restOperations; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
RestOperationsResourceRetriever(RestOperations restOperations) { |
|
|
|
if(reentrantLock.tryLock()) { |
|
|
|
Assert.notNull(restOperations, "restOperations cannot be null"); |
|
|
|
try { |
|
|
|
this.restOperations = restOperations; |
|
|
|
String jwkSetUri = this.cache.get(this.jwkSetUri, String.class); |
|
|
|
|
|
|
|
JWKSet cacheJwkSet = parse(jwkSetUri); |
|
|
|
|
|
|
|
if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) { |
|
|
|
|
|
|
|
try { |
|
|
|
|
|
|
|
jwkSet = fetchJWKSet(); |
|
|
|
|
|
|
|
} catch (IOException e) { |
|
|
|
|
|
|
|
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} else if (jwkSetUri != null) { |
|
|
|
|
|
|
|
jwkSet = parse(jwkSetUri); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} finally { |
|
|
|
|
|
|
|
reentrantLock.unlock(); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if(jwkSet == null) { |
|
|
|
|
|
|
|
return Collections.emptyList(); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
return jwkSelector.select(jwkSet); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
private JWKSet fetchJWKSet() throws IOException, KeySourceException { |
|
|
|
public Resource retrieveResource(URL url) throws IOException { |
|
|
|
|
|
|
|
HttpHeaders headers = new HttpHeaders(); |
|
|
|
HttpHeaders headers = new HttpHeaders(); |
|
|
|
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); |
|
|
|
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); |
|
|
|
ResponseEntity<String> response = getResponse(url, headers); |
|
|
|
ResponseEntity<String> response = getResponse(headers); |
|
|
|
if (response.getStatusCode().value() != 200) { |
|
|
|
if (response.getStatusCode().value() != 200) { |
|
|
|
throw new IOException(response.toString()); |
|
|
|
throw new IOException(response.toString()); |
|
|
|
} |
|
|
|
} |
|
|
|
return new Resource(response.getBody(), "UTF-8"); |
|
|
|
try { |
|
|
|
|
|
|
|
String jwkSet = response.getBody(); |
|
|
|
|
|
|
|
this.cache.put(this.jwkSetUri, jwkSet); |
|
|
|
|
|
|
|
return JWKSet.parse(jwkSet); |
|
|
|
|
|
|
|
} catch (ParseException e) { |
|
|
|
|
|
|
|
throw new JWKSetParseException("Unable to parse JWK set", e); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private ResponseEntity<String> getResponse(URL url, HttpHeaders headers) throws IOException { |
|
|
|
private ResponseEntity<String> getResponse(HttpHeaders headers) throws IOException { |
|
|
|
try { |
|
|
|
try { |
|
|
|
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI()); |
|
|
|
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI()); |
|
|
|
return this.restOperations.exchange(request, String.class); |
|
|
|
return this.restOperations.exchange(request, String.class); |
|
|
|
} |
|
|
|
} catch (Exception ex) { |
|
|
|
catch (Exception ex) { |
|
|
|
|
|
|
|
throw new IOException(ex); |
|
|
|
throw new IOException(ex); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private JWKSet parse(String cachedJwkSet) { |
|
|
|
|
|
|
|
JWKSet jwkSet = null; |
|
|
|
|
|
|
|
try { |
|
|
|
|
|
|
|
jwkSet = JWKSet.parse(cachedJwkSet); |
|
|
|
|
|
|
|
} catch (ParseException ignored) { |
|
|
|
|
|
|
|
// Ignore invalid cache value
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
return jwkSet; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) { |
|
|
|
|
|
|
|
Set<String> keyIDs = jwkMatcher.getKeyIDs(); |
|
|
|
|
|
|
|
return (keyIDs == null || keyIDs.isEmpty()) ? |
|
|
|
|
|
|
|
null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|