Browse Source

Remove Deprecated Usages of RemoteJWKSet

Closes gh-16251

Signed-off-by: Daeho Kwon <trewq231@naver.com>
pull/16476/head
Daeho Kwon 11 months ago committed by Josh Cummings
parent
commit
7b7abb28bb
  1. 180
      oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java
  2. 3
      oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java

180
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2023 the original author or authors. * Copyright 2002-2025 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,12 @@
package org.springframework.security.oauth2.jwt; package org.springframework.security.oauth2.jwt;
import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.source.JWKSetParseException;
import com.nimbusds.jose.jwk.source.JWKSetRetrievalException;
import java.io.IOException; import java.io.IOException;
import java.net.MalformedURLException; import java.net.MalformedURLException;
import java.net.URL; import java.net.URL;
@ -26,8 +32,10 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
@ -35,17 +43,12 @@ import javax.crypto.SecretKey;
import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.proc.SingleKeyJWSKeySelector; import com.nimbusds.jose.proc.SingleKeyJWSKeySelector;
import com.nimbusds.jose.util.Resource;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser; import com.nimbusds.jwt.JWTParser;
@ -57,6 +60,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.cache.Cache; import org.springframework.cache.Cache;
import org.springframework.cache.support.NoOpCache;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
@ -80,6 +84,7 @@ import org.springframework.web.client.RestTemplate;
* @author Josh Cummings * @author Josh Cummings
* @author Joe Grandja * @author Joe Grandja
* @author Mykyta Bezverkhyi * @author Mykyta Bezverkhyi
* @author Daeho Kwon
* @since 5.2 * @since 5.2
*/ */
public final class NimbusJwtDecoder implements JwtDecoder { public final class NimbusJwtDecoder implements JwtDecoder {
@ -165,7 +170,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
.build(); .build();
// @formatter:on // @formatter:on
} }
catch (RemoteKeySourceException ex) { catch (KeySourceException ex) {
this.logger.trace("Failed to retrieve JWK set", ex); this.logger.trace("Failed to retrieve JWK set", ex);
if (ex.getCause() instanceof ParseException) { if (ex.getCause() instanceof ParseException) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex); throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
@ -273,7 +278,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
private RestOperations restOperations = new RestTemplate(); private RestOperations restOperations = new RestTemplate();
private Cache cache; private Cache cache = new NoOpCache("default");
private Consumer<ConfigurableJWTProcessor<SecurityContext>> jwtProcessorCustomizer; private Consumer<ConfigurableJWTProcessor<SecurityContext>> jwtProcessorCustomizer;
@ -376,18 +381,13 @@ public final class NimbusJwtDecoder implements JwtDecoder {
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
} }
JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) { JWKSource<SecurityContext> jwkSource() {
if (this.cache == null) { String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever); return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri);
}
JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache);
} }
JWTProcessor<SecurityContext> processor() { JWTProcessor<SecurityContext> processor() {
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations); JWKSource<SecurityContext> jwkSource = jwkSource();
String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever, jwkSetUri);
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>(); ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource)); jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));
// Spring Security validates the claim set independent from Nimbus // Spring Security validates the claim set independent from Nimbus
@ -414,84 +414,130 @@ public final class NimbusJwtDecoder implements JwtDecoder {
} }
} }
private static final class SpringJWKSetCache implements JWKSetCache { private static final class SpringJWKSource<C extends SecurityContext> implements JWKSource<C> {
private final String jwkSetUri; private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
private final ReentrantLock reentrantLock = new ReentrantLock();
private final RestOperations restOperations;
private final Cache cache; private final Cache cache;
private JWKSet jwkSet; private final URL url;
SpringJWKSetCache(String jwkSetUri, Cache cache) { private final String jwkSetUri;
this.jwkSetUri = jwkSetUri;
private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) {
Assert.notNull(restOperations, "restOperations cannot be null");
this.restOperations = restOperations;
this.cache = cache; this.cache = cache;
this.updateJwkSetFromCache(); this.url = url;
this.jwkSetUri = jwkSetUri;
} }
private void updateJwkSetFromCache() {
@Override
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class); String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
JWKSet jwkSet = null;
if (cachedJwkSet != null) { if (cachedJwkSet != null) {
try { jwkSet = parse(cachedJwkSet);
this.jwkSet = JWKSet.parse(cachedJwkSet); }
} if (jwkSet == null) {
catch (ParseException ignored) { if(reentrantLock.tryLock()) {
// Ignore invalid cache value try {
String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class);
if (cachedJwkSetAfterLock != null) {
jwkSet = parse(cachedJwkSetAfterLock);
}
if(jwkSet == null) {
try {
jwkSet = fetchJWKSet();
} catch (IOException e) {
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
}
}
} finally {
reentrantLock.unlock();
}
} }
} }
} List<JWK> matches = jwkSelector.select(jwkSet);
if(!matches.isEmpty()) {
// Note: Only called from inside a synchronized block in RemoteJWKSet. return matches;
@Override }
public void put(JWKSet jwkSet) { String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
this.jwkSet = jwkSet; if (soughtKeyID == null) {
this.cache.put(this.jwkSetUri, jwkSet.toString(false)); return Collections.emptyList();
} }
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
@Override return Collections.emptyList();
public JWKSet get() { }
return (!requiresRefresh()) ? this.jwkSet : null;
}
@Override
public boolean requiresRefresh() {
return this.cache.get(this.jwkSetUri) == null;
}
}
private static class RestOperationsResourceRetriever implements ResourceRetriever {
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
private final RestOperations restOperations;
RestOperationsResourceRetriever(RestOperations restOperations) { if(reentrantLock.tryLock()) {
Assert.notNull(restOperations, "restOperations cannot be null"); try {
this.restOperations = restOperations; String jwkSetUri = this.cache.get(this.jwkSetUri, String.class);
JWKSet cacheJwkSet = parse(jwkSetUri);
if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) {
try {
jwkSet = fetchJWKSet();
} catch (IOException e) {
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
}
} else if (jwkSetUri != null) {
jwkSet = parse(jwkSetUri);
}
} finally {
reentrantLock.unlock();
}
}
if(jwkSet == null) {
return Collections.emptyList();
}
return jwkSelector.select(jwkSet);
} }
@Override private JWKSet fetchJWKSet() throws IOException, KeySourceException {
public Resource retrieveResource(URL url) throws IOException {
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
ResponseEntity<String> response = getResponse(url, headers); ResponseEntity<String> response = getResponse(headers);
if (response.getStatusCode().value() != 200) { if (response.getStatusCode().value() != 200) {
throw new IOException(response.toString()); throw new IOException(response.toString());
} }
return new Resource(response.getBody(), "UTF-8"); try {
String jwkSet = response.getBody();
this.cache.put(this.jwkSetUri, jwkSet);
return JWKSet.parse(jwkSet);
} catch (ParseException e) {
throw new JWKSetParseException("Unable to parse JWK set", e);
}
} }
private ResponseEntity<String> getResponse(URL url, HttpHeaders headers) throws IOException { private ResponseEntity<String> getResponse(HttpHeaders headers) throws IOException {
try { try {
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI()); RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI());
return this.restOperations.exchange(request, String.class); return this.restOperations.exchange(request, String.class);
} } catch (Exception ex) {
catch (Exception ex) {
throw new IOException(ex); throw new IOException(ex);
} }
} }
private JWKSet parse(String cachedJwkSet) {
JWKSet jwkSet = null;
try {
jwkSet = JWKSet.parse(cachedJwkSet);
} catch (ParseException ignored) {
// Ignore invalid cache value
}
return jwkSet;
}
private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
Set<String> keyIDs = jwkMatcher.getKeyIDs();
return (keyIDs == null || keyIDs.isEmpty()) ?
null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null);
}
} }
} }

3
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2025 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -308,6 +308,7 @@ public class JwtDecodersTests {
private void prepareConfigurationResponse(String body) { private void prepareConfigurationResponse(String body) {
this.server.enqueue(response(body)); this.server.enqueue(response(body));
this.server.enqueue(response(JWK_SET)); this.server.enqueue(response(JWK_SET));
this.server.enqueue(response(JWK_SET)); // default NoOpCache
} }
private void prepareConfigurationResponseOidc() { private void prepareConfigurationResponseOidc() {

Loading…
Cancel
Save