@ -40,6 +40,8 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor ;
import com.nimbusds.jwt.proc.JWTProcessor ;
import reactor.core.publisher.Mono ;
import reactor.core.publisher.Mono ;
import org.springframework.security.oauth2.core.OAuth2TokenValidator ;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult ;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms ;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms ;
import org.springframework.util.Assert ;
import org.springframework.util.Assert ;
@ -67,6 +69,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
private final JWKSelectorFactory jwkSelectorFactory ;
private final JWKSelectorFactory jwkSelectorFactory ;
private OAuth2TokenValidator < Jwt > jwtValidator = JwtValidators . createDefault ( ) ;
public NimbusReactiveJwtDecoder ( RSAPublicKey publicKey ) {
public NimbusReactiveJwtDecoder ( RSAPublicKey publicKey ) {
JWSAlgorithm algorithm = JWSAlgorithm . parse ( JwsAlgorithms . RS256 ) ;
JWSAlgorithm algorithm = JWSAlgorithm . parse ( JwsAlgorithms . RS256 ) ;
@ -77,6 +81,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
new JWSVerificationKeySelector < > ( algorithm , jwkSource ) ;
new JWSVerificationKeySelector < > ( algorithm , jwkSource ) ;
DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor < > ( ) ;
DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor < > ( ) ;
jwtProcessor . setJWSKeySelector ( jwsKeySelector ) ;
jwtProcessor . setJWSKeySelector ( jwsKeySelector ) ;
jwtProcessor . setJWTClaimsSetVerifier ( ( claims , context ) - > { } ) ;
this . jwtProcessor = jwtProcessor ;
this . jwtProcessor = jwtProcessor ;
this . reactiveJwkSource = new ReactiveJWKSourceAdapter ( jwkSource ) ;
this . reactiveJwkSource = new ReactiveJWKSourceAdapter ( jwkSource ) ;
@ -98,6 +103,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
DefaultJWTProcessor < JWKContext > jwtProcessor = new DefaultJWTProcessor < > ( ) ;
DefaultJWTProcessor < JWKContext > jwtProcessor = new DefaultJWTProcessor < > ( ) ;
jwtProcessor . setJWSKeySelector ( jwsKeySelector ) ;
jwtProcessor . setJWSKeySelector ( jwsKeySelector ) ;
jwtProcessor . setJWTClaimsSetVerifier ( ( claims , context ) - > { } ) ;
this . jwtProcessor = jwtProcessor ;
this . jwtProcessor = jwtProcessor ;
this . reactiveJwkSource = new ReactiveRemoteJWKSource ( jwkSetUrl ) ;
this . reactiveJwkSource = new ReactiveRemoteJWKSource ( jwkSetUrl ) ;
@ -106,6 +112,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
}
}
/ * *
* Use the provided { @link OAuth2TokenValidator } to validate incoming { @link Jwt } s .
*
* @param jwtValidator the { @link OAuth2TokenValidator } to use
* /
public void setJwtValidator ( OAuth2TokenValidator < Jwt > jwtValidator ) {
Assert . notNull ( jwtValidator , "jwtValidator cannot be null" ) ;
this . jwtValidator = jwtValidator ;
}
@Override
@Override
public Mono < Jwt > decode ( String token ) throws JwtException {
public Mono < Jwt > decode ( String token ) throws JwtException {
JWT jwt = parse ( token ) ;
JWT jwt = parse ( token ) ;
@ -131,7 +147,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
. onErrorMap ( e - > new IllegalStateException ( "Could not obtain the keys" , e ) )
. onErrorMap ( e - > new IllegalStateException ( "Could not obtain the keys" , e ) )
. map ( jwkList - > createClaimsSet ( parsedToken , jwkList ) )
. map ( jwkList - > createClaimsSet ( parsedToken , jwkList ) )
. map ( set - > createJwt ( parsedToken , set ) )
. map ( set - > createJwt ( parsedToken , set ) )
. onErrorMap ( e - > ! ( e instanceof IllegalStateException ) , e - > new JwtException ( "An error occurred while attempting to decode the Jwt: " , e ) ) ;
. map ( this : : validateJwt )
. onErrorMap ( e - > ! ( e instanceof IllegalStateException ) & & ! ( e instanceof JwtException ) , e - > new JwtException ( "An error occurred while attempting to decode the Jwt: " , e ) ) ;
} catch ( RuntimeException ex ) {
} catch ( RuntimeException ex ) {
throw new JwtException ( "An error occurred while attempting to decode the Jwt: " + ex . getMessage ( ) , ex ) ;
throw new JwtException ( "An error occurred while attempting to decode the Jwt: " + ex . getMessage ( ) , ex ) ;
}
}
@ -164,6 +181,17 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
return new Jwt ( parsedJwt . getParsedString ( ) , issuedAt , expiresAt , headers , jwtClaimsSet . getClaims ( ) ) ;
return new Jwt ( parsedJwt . getParsedString ( ) , issuedAt , expiresAt , headers , jwtClaimsSet . getClaims ( ) ) ;
}
}
private Jwt validateJwt ( Jwt jwt ) {
OAuth2TokenValidatorResult result = this . jwtValidator . validate ( jwt ) ;
if ( result . hasErrors ( ) ) {
String message = result . getErrors ( ) . iterator ( ) . next ( ) . getDescription ( ) ;
throw new JwtValidationException ( message , result . getErrors ( ) ) ;
}
return jwt ;
}
private static RSAKey rsaKey ( RSAPublicKey publicKey ) {
private static RSAKey rsaKey ( RSAPublicKey publicKey ) {
return new RSAKey . Builder ( publicKey )
return new RSAKey . Builder ( publicKey )
. build ( ) ;
. build ( ) ;