diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java index bba502dfd9..fb0468fa9b 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java @@ -87,17 +87,12 @@ public final class NimbusJwtEncoder implements JwtEncoder { private final JWKSource jwkSource; - private Converter, JWK> jwkSelector= (jwks)->{ - if (jwks.size() > 1) { - throw new JwtEncodingException(String.format( - "Failed to select a key since there are multiple for the signing algorithm [%s]; " + - "please specify a selector in NimbusJwsEncoder#setJwkSelector",jwks.get(0).getAlgorithm())); - } - if (jwks.isEmpty()) { - throw new JwtEncodingException( - String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); - } - return jwks.get(0); + private Converter, JWK> jwkSelector = (jwks) -> { + throw new JwtEncodingException( + String.format( + "Failed to select a key since there are multiple for the signing algorithm [%s]; " + + "please specify a selector in NimbusJwsEncoder#setJwkSelector", + jwks.get(0).getAlgorithm())); }; /** @@ -108,17 +103,20 @@ public final class NimbusJwtEncoder implements JwtEncoder { Assert.notNull(jwkSource, "jwkSource cannot be null"); this.jwkSource = jwkSource; } + /** - * Use this strategy to reduce the list of matching JWKs down to a since one. - *

For example, you can call {@code setJwkSelector(List::getFirst)} in order - * to have this encoder select the first match. + * Use this strategy to reduce the list of matching JWKs when there is more than one. + *

+ * For example, you can call {@code setJwkSelector(List::getFirst)} in order to have + * this encoder select the first match. * - *

By default, the class with throw an exception if there is more than one result. + *

+ * By default, the class with throw an exception. * @since 6.5 */ public void setJwkSelector(Converter, JWK> jwkSelector) { - if(null!=jwkSelector) - this.jwkSelector = jwkSelector; + Assert.notNull(jwkSelector, "jwkSelector cannot be null"); + this.jwkSelector = jwkSelector; } @Override @@ -149,6 +147,13 @@ public final class NimbusJwtEncoder implements JwtEncoder { throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key -> " + ex.getMessage()), ex); } + if (jwks.isEmpty()) { + throw new JwtEncodingException( + String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); + } + if (jwks.size() == 1) { + return jwks.get(0); + } return this.jwkSelector.convert(jwks); } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java index 412adbfd4d..d0426a4533 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,10 @@ public final class TestJwks { private TestJwks() { } + public static RSAKey.Builder rsa() { + return jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY); + } + public static RSAKey.Builder jwk(RSAPublicKey publicKey, RSAPrivateKey privateKey) { // @formatter:off return new RSAKey.Builder(publicKey) diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index e9825f0a35..ab17156eac 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.KeySourceException; import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.JWK; @@ -39,6 +40,7 @@ import org.junit.jupiter.api.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; @@ -51,6 +53,8 @@ import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; /** * Tests for {@link NimbusJwtEncoder}. @@ -109,7 +113,7 @@ public class NimbusJwtEncoderTests { @Test public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception { - RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + RSAKey rsaJwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build(); this.jwkList.add(rsaJwk); this.jwkList.add(rsaJwk); @@ -118,7 +122,7 @@ public class NimbusJwtEncoderTests { assertThatExceptionOfType(JwtEncodingException.class) .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet))) - .withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'"); + .withMessageContaining("Failed to select a key since there are multiple for the signing algorithm [RS256]"); } @Test @@ -291,6 +295,55 @@ public class NimbusJwtEncoderTests { assertThat(jwk1.getKeyID()).isNotEqualTo(jwk2.getKeyID()); } + @Test + public void encodeWhenMultipleKeysThenJwkSelectorUsed() throws Exception { + JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build(); + JWKSource jwkSource = mock(JWKSource.class); + given(jwkSource.get(any(), any())).willReturn(List.of(jwk, jwk)); + Converter, JWK> selector = mock(Converter.class); + given(selector.convert(any())).willReturn(TestJwks.DEFAULT_RSA_JWK); + + NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource); + jwtEncoder.setJwkSelector(selector); + + JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build(); + jwtEncoder.encode(JwtEncoderParameters.from(claims)); + + verify(selector).convert(any()); + } + + @Test + public void encodeWhenSingleKeyThenJwkSelectorIsNotUsed() throws Exception { + JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build(); + JWKSource jwkSource = mock(JWKSource.class); + given(jwkSource.get(any(), any())).willReturn(List.of(jwk)); + Converter, JWK> selector = mock(Converter.class); + + NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource); + jwtEncoder.setJwkSelector(selector); + + JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build(); + jwtEncoder.encode(JwtEncoderParameters.from(claims)); + + verifyNoInteractions(selector); + } + + @Test + public void encodeWhenNoKeysThenJwkSelectorIsNotUsed() throws Exception { + JWKSource jwkSource = mock(JWKSource.class); + given(jwkSource.get(any(), any())).willReturn(List.of()); + Converter, JWK> selector = mock(Converter.class); + + NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource); + jwtEncoder.setJwkSelector(selector); + + JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build(); + assertThatExceptionOfType(JwtEncodingException.class) + .isThrownBy(() -> jwtEncoder.encode(JwtEncoderParameters.from(claims))); + + verifyNoInteractions(selector); + } + private static final class JwkListResultCaptor implements Answer> { private List result;