From 90e5f45e1f3a2c0b1e2d63faaabf3df6b86ee34f Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Thu, 30 Jul 2020 16:22:49 -0600 Subject: [PATCH] Polish to Avoid NPE Issue gh-5648 Co-authored-by: MattyA --- .../security/oauth2/jwt/NimbusJwtDecoder.java | 15 ++++++- .../oauth2/jwt/NimbusReactiveJwtDecoder.java | 42 +++++++++++-------- .../oauth2/jwt/NimbusJwtDecoderTests.java | 16 +++++++ .../jwt/NimbusReactiveJwtDecoderTests.java | 16 +++++++ 4 files changed, 70 insertions(+), 19 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 03993b32e6..317d23a30d 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -22,6 +22,7 @@ import java.net.URL; import java.security.interfaces.RSAPublicKey; import java.text.ParseException; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; @@ -57,11 +58,13 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; @@ -170,9 +173,17 @@ public final class NimbusJwtDecoder implements JwtDecoder { private Jwt validateJwt(Jwt jwt){ OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt); if (result.hasErrors()) { - String description = result.getErrors().iterator().next().getDescription(); + Collection errors = result.getErrors(); + String validationErrorString = "Unable to validate Jwt"; + for (OAuth2Error oAuth2Error : errors) { + if (!StringUtils.isEmpty(oAuth2Error.getDescription())) { + validationErrorString = String.format( + DECODING_ERROR_MESSAGE_TEMPLATE, oAuth2Error.getDescription()); + break; + } + } throw new JwtValidationException( - String.format(DECODING_ERROR_MESSAGE_TEMPLATE, description), + validationErrorString, result.getErrors()); } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index 934a8297e9..9bbedfd306 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -15,16 +15,6 @@ */ package org.springframework.security.oauth2.jwt; -import java.security.interfaces.RSAPublicKey; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.Set; -import java.util.function.Consumer; -import java.util.function.Function; -import javax.crypto.SecretKey; - import com.nimbusds.jose.Header; import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; @@ -47,17 +37,29 @@ import com.nimbusds.jwt.SignedJWT; import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.JWTProcessor; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import javax.crypto.SecretKey; +import java.security.interfaces.RSAPublicKey; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; /** * An implementation of a {@link ReactiveJwtDecoder} that "decodes" a @@ -178,10 +180,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { private Jwt validateJwt(Jwt jwt) { OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt); - - if ( result.hasErrors() ) { - String message = result.getErrors().iterator().next().getDescription(); - throw new JwtValidationException(message, result.getErrors()); + if (result.hasErrors()) { + Collection errors = result.getErrors(); + String validationErrorString = "Unable to validate Jwt"; + for (OAuth2Error oAuth2Error : errors) { + if (!StringUtils.isEmpty(oAuth2Error.getDescription())) { + validationErrorString = oAuth2Error.getDescription(); + break; + } + } + throw new JwtValidationException(validationErrorString, errors); } return jwt; diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index a6f5b5e136..9e1e8e95b3 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -202,6 +202,22 @@ public class NimbusJwtDecoderTests { .hasFieldOrPropertyWithValue("errors", Arrays.asList(firstFailure, secondFailure)); } + @Test + public void decodeWhenReadingErrorPickTheFirstErrorMessage() { + OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); + this.jwtDecoder.setJwtValidator(jwtValidator); + + OAuth2Error errorEmpty = new OAuth2Error("mock-error", "", "mock-uri"); + OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri"); + OAuth2Error error2 = new OAuth2Error("mock-error-second", "mock-description-second", "mock-uri-second"); + OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(errorEmpty, error, error2); + when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); + + Assertions.assertThatCode(() -> this.jwtDecoder.decode(SIGNED_JWT)) + .isInstanceOf(JwtValidationException.class) + .hasMessageContaining("mock-description"); + } + @Test public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() { Converter, Map> claimSetConverter = mock(Converter.class); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index d701b7b6f3..8c884dc8ab 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -224,6 +224,22 @@ public class NimbusReactiveJwtDecoderTests { .hasMessageContaining("mock-description"); } + @Test + public void decodeWhenReadingErrorPickTheFirstErrorMessage() { + OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); + this.decoder.setJwtValidator(jwtValidator); + + OAuth2Error errorEmpty = new OAuth2Error("mock-error", "", "mock-uri"); + OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri"); + OAuth2Error error2 = new OAuth2Error("mock-error-second", "mock-description-second", "mock-uri-second"); + OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(errorEmpty, error, error2); + when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); + + assertThatCode(() -> this.decoder.decode(this.messageReadToken).block()) + .isInstanceOf(JwtValidationException.class) + .hasMessageContaining("mock-description"); + } + @Test public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() { Converter, Map> claimSetConverter = mock(Converter.class);