diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWSAlgorithmMapJWSKeySelector.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWSAlgorithmMapJWSKeySelector.java deleted file mode 100644 index 2947e90ff5..0000000000 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWSAlgorithmMapJWSKeySelector.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.jwt; - -import java.security.Key; -import java.util.List; -import java.util.Map; -import java.util.Set; - -import com.nimbusds.jose.JWSAlgorithm; -import com.nimbusds.jose.JWSHeader; -import com.nimbusds.jose.KeySourceException; -import com.nimbusds.jose.proc.JWSKeySelector; -import com.nimbusds.jose.proc.SecurityContext; - -/** - * Class for delegating to a Nimbus JWSKeySelector by the given JWSAlgorithm - * - * @author Josh Cummings - */ -class JWSAlgorithmMapJWSKeySelector implements JWSKeySelector { - private Map> jwsKeySelectors; - - JWSAlgorithmMapJWSKeySelector(Map> jwsKeySelectors) { - this.jwsKeySelectors = jwsKeySelectors; - } - - @Override - public List selectJWSKeys(JWSHeader header, C context) throws KeySourceException { - JWSKeySelector keySelector = this.jwsKeySelectors.get(header.getAlgorithm()); - if (keySelector == null) { - throw new IllegalArgumentException("Unsupported algorithm of " + header.getAlgorithm()); - } - return keySelector.selectJWSKeys(header, context); - } - - public Set getExpectedJWSAlgorithms() { - return this.jwsKeySelectors.keySet(); - } -} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index b8e805fdfd..44daabd13c 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -23,7 +23,6 @@ import java.security.interfaces.RSAPublicKey; import java.text.ParseException; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; @@ -286,16 +285,13 @@ public final class NimbusJwtDecoder implements JwtDecoder { JWSKeySelector jwsKeySelector(JWKSource jwkSource) { if (this.signatureAlgorithms.isEmpty()) { return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource); - } else if (this.signatureAlgorithms.size() == 1) { - JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName()); - return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource); } else { - Map> jwsKeySelectors = new HashMap<>(); + Set jwsAlgorithms = new HashSet<>(); for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) { - JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName()); - jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource)); + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); + jwsAlgorithms.add(jwsAlgorithm); } - return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors); + return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); } } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index c80bbb4a3a..fa82a3899c 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -17,7 +17,6 @@ package org.springframework.security.oauth2.jwt; import java.security.interfaces.RSAPublicKey; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; @@ -307,16 +306,13 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { JWSKeySelector jwsKeySelector(JWKSource jwkSource) { if (this.signatureAlgorithms.isEmpty()) { return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource); - } else if (this.signatureAlgorithms.size() == 1) { - JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName()); - return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource); } else { - Map> jwsKeySelectors = new HashMap<>(); + Set jwsAlgorithms = new HashSet<>(); for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) { - JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName()); - jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource)); + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); + jwsAlgorithms.add(jwsAlgorithm); } - return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors); + return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); } } @@ -330,7 +326,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri); source.setWebClient(this.webClient); - Set expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector); + Function expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector); return jwt -> { JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader()); return source.get(selector) @@ -339,22 +335,20 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { }; } - private Set getExpectedJwsAlgorithms(JWSKeySelector jwsKeySelector) { + private Function getExpectedJwsAlgorithms(JWSKeySelector jwsKeySelector) { if (jwsKeySelector instanceof JWSVerificationKeySelector) { - return Collections.singleton(((JWSVerificationKeySelector) jwsKeySelector).getExpectedJWSAlgorithm()); - } - if (jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector) { - return ((JWSAlgorithmMapJWSKeySelector) jwsKeySelector).getExpectedJWSAlgorithms(); + return ((JWSVerificationKeySelector) jwsKeySelector)::isAllowed; } throw new IllegalArgumentException("Unsupported key selector type " + jwsKeySelector.getClass()); } - private JWKSelector createSelector(Set expectedJwsAlgorithms, Header header) { - if (!expectedJwsAlgorithms.contains(header.getAlgorithm())) { + private JWKSelector createSelector(Function expectedJwsAlgorithms, Header header) { + JWSHeader jwsHeader = (JWSHeader) header; + if (!expectedJwsAlgorithms.apply(jwsHeader.getAlgorithm())) { throw new BadJwtException("Unsupported algorithm of " + header.getAlgorithm()); } - return new JWKSelector(JWKMatcher.forJWSHeader((JWSHeader) header)); + return new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader)); } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index 9e3608b3be..a6ff351287 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -415,8 +415,8 @@ public class NimbusJwtDecoderTests { assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); JWSVerificationKeySelector jwsVerificationKeySelector = (JWSVerificationKeySelector) jwsKeySelector; - assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm()) - .isEqualTo(JWSAlgorithm.RS256); + assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS256)) + .isTrue(); } @Test @@ -428,8 +428,8 @@ public class NimbusJwtDecoderTests { assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); JWSVerificationKeySelector jwsVerificationKeySelector = (JWSVerificationKeySelector) jwsKeySelector; - assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm()) - .isEqualTo(JWSAlgorithm.RS512); + assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS512)) + .isTrue(); } @Test @@ -440,11 +440,13 @@ public class NimbusJwtDecoderTests { .jwsAlgorithm(SignatureAlgorithm.RS256) .jwsAlgorithm(SignatureAlgorithm.RS512) .jwsKeySelector(jwkSource); - assertThat(jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector); - JWSAlgorithmMapJWSKeySelector jwsAlgorithmMapKeySelector = - (JWSAlgorithmMapJWSKeySelector) jwsKeySelector; - assertThat(jwsAlgorithmMapKeySelector.getExpectedJWSAlgorithms()) - .containsExactlyInAnyOrder(JWSAlgorithm.RS256, JWSAlgorithm.RS512); + assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); + JWSVerificationKeySelector jwsAlgorithmMapKeySelector = + (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS256)) + .isTrue(); + assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS512)) + .isTrue(); } // gh-7290 diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index 74a9d8f671..3109031fdb 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -395,8 +395,8 @@ public class NimbusReactiveJwtDecoderTests { assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); JWSVerificationKeySelector jwsVerificationKeySelector = (JWSVerificationKeySelector) jwsKeySelector; - assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm()) - .isEqualTo(JWSAlgorithm.RS256); + assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS256)) + .isTrue(); } @Test @@ -408,8 +408,8 @@ public class NimbusReactiveJwtDecoderTests { assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); JWSVerificationKeySelector jwsVerificationKeySelector = (JWSVerificationKeySelector) jwsKeySelector; - assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm()) - .isEqualTo(JWSAlgorithm.RS512); + assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS512)) + .isTrue(); } @Test @@ -420,11 +420,13 @@ public class NimbusReactiveJwtDecoderTests { .jwsAlgorithm(SignatureAlgorithm.RS256) .jwsAlgorithm(SignatureAlgorithm.RS512) .jwsKeySelector(jwkSource); - assertThat(jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector); - JWSAlgorithmMapJWSKeySelector jwsAlgorithmMapKeySelector = - (JWSAlgorithmMapJWSKeySelector) jwsKeySelector; - assertThat(jwsAlgorithmMapKeySelector.getExpectedJWSAlgorithms()) - .containsExactlyInAnyOrder(JWSAlgorithm.RS256, JWSAlgorithm.RS512); + assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); + JWSVerificationKeySelector jwsAlgorithmMapKeySelector = + (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS256)) + .isTrue(); + assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS512)) + .isTrue(); } private SignedJWT signedJwt(SecretKey secretKey, MacAlgorithm jwsAlgorithm, JWTClaimsSet claimsSet) throws Exception {