diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index 4b97d8def0..ca7d1ec451 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -40,6 +40,8 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.JWTProcessor; import reactor.core.publisher.Mono; +import org.springframework.security.oauth2.core.OAuth2TokenValidator; +import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.util.Assert; @@ -67,6 +69,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { private final JWKSelectorFactory jwkSelectorFactory; + private OAuth2TokenValidator jwtValidator = JwtValidators.createDefault(); + public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) { JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256); @@ -77,6 +81,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { new JWSVerificationKeySelector<>(algorithm, jwkSource); DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector); + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {}); this.jwtProcessor = jwtProcessor; this.reactiveJwkSource = new ReactiveJWKSourceAdapter(jwkSource); @@ -98,6 +103,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector); + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {}); this.jwtProcessor = jwtProcessor; this.reactiveJwkSource = new ReactiveRemoteJWKSource(jwkSetUrl); @@ -106,6 +112,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } + /** + * Use the provided {@link OAuth2TokenValidator} to validate incoming {@link Jwt}s. + * + * @param jwtValidator the {@link OAuth2TokenValidator} to use + */ + public void setJwtValidator(OAuth2TokenValidator jwtValidator) { + Assert.notNull(jwtValidator, "jwtValidator cannot be null"); + this.jwtValidator = jwtValidator; + } + @Override public Mono decode(String token) throws JwtException { JWT jwt = parse(token); @@ -131,7 +147,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { .onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e)) .map(jwkList -> createClaimsSet(parsedToken, jwkList)) .map(set -> createJwt(parsedToken, set)) - .onErrorMap(e -> !(e instanceof IllegalStateException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e)); + .map(this::validateJwt) + .onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e)); } catch (RuntimeException ex) { throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); } @@ -164,6 +181,17 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims()); } + private Jwt validateJwt(Jwt jwt) { + OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt); + + if ( result.hasErrors() ) { + String message = result.getErrors().iterator().next().getDescription(); + throw new JwtValidationException(message, result.getErrors()); + } + + return jwt; + } + private static RSAKey rsaKey(RSAPublicKey publicKey) { return new RSAKey.Builder(publicKey) .build(); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index 2b3ca53adf..daed03cadc 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -16,12 +16,6 @@ package org.springframework.security.oauth2.jwt; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - import java.net.UnknownHostException; import java.security.KeyFactory; import java.security.interfaces.RSAPublicKey; @@ -29,8 +23,21 @@ import java.security.spec.X509EncodedKeySpec; import java.util.Base64; import java.util.Date; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2TokenValidator; +import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -114,7 +121,7 @@ public class NimbusReactiveJwtDecoderTests { @Test public void decodeWhenExpiredThenFail() { assertThatCode(() -> this.decoder.decode(this.expired).block()) - .isInstanceOf(JwtException.class); + .isInstanceOf(JwtValidationException.class); } @Test @@ -155,4 +162,24 @@ public class NimbusReactiveJwtDecoderTests { .isInstanceOf(JwtException.class) .hasMessage("Unsupported algorithm of none"); } + + @Test + public void decodeWhenUsingCustomValidatorThenValidatorIsInvoked() { + OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); + this.decoder.setJwtValidator(jwtValidator); + + OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri"); + OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(error); + when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); + + assertThatCode(() -> this.decoder.decode(messageReadToken).block()) + .isInstanceOf(JwtException.class) + .hasMessageContaining("mock-description"); + } + + @Test + public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() { + assertThatCode(() -> this.decoder.setJwtValidator(null)) + .isInstanceOf(IllegalArgumentException.class); + } }