diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java index 6a66651a44..566b248d6a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.function.Function; +import java.util.function.Predicate; import reactor.core.publisher.Mono; @@ -33,6 +34,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -71,6 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< private Function, Map>> claimTypeConverterFactory = ( clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; + private Predicate retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo; + /** * Returns the default {@link Converter}'s used for type conversion of claim values * for an {@link OidcUserInfo}. @@ -123,7 +127,7 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< } private Mono getUserInfo(OidcUserRequest userRequest) { - if (!OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest)) { + if (!this.retrieveUserInfo.test(userRequest)) { return Mono.empty(); } // @formatter:off @@ -169,4 +173,24 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< this.claimTypeConverterFactory = claimTypeConverterFactory; } + /** + * Sets the {@code Predicate} used to determine if the UserInfo Endpoint should be + * called to retrieve information about the End-User (Resource Owner). + *

+ * By default, the UserInfo Endpoint is called if all of the following are true: + *

    + *
  • The user info endpoint is defined on the ClientRegistration
  • + *
  • The Client Registration uses the + * {@link AuthorizationGrantType#AUTHORIZATION_CODE} and scopes in the access token + * are defined in the {@link ClientRegistration}
  • + *
+ * @param retrieveUserInfo the function used to determine if the UserInfo Endpoint + * should be called + * @since 6.3 + */ + public final void setRetrieveUserInfo(Predicate retrieveUserInfo) { + Assert.notNull(retrieveUserInfo, "retrieveUserInfo cannot be null"); + this.retrieveUserInfo = retrieveUserInfo; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index 0ae4727ff7..a0d1cd26aa 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; import java.util.function.Function; +import java.util.function.Predicate; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; @@ -78,6 +79,8 @@ public class OidcUserService implements OAuth2UserService, Map>> claimTypeConverterFactory = ( clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; + private Predicate retrieveUserInfo = this::shouldRetrieveUserInfo; + /** * Returns the default {@link Converter}'s used for type conversion of claim values * for an {@link OidcUserInfo}. @@ -105,7 +108,7 @@ public class OidcUserService implements OAuth2UserService claims = getClaims(userRequest, oauth2User); userInfo = new OidcUserInfo(claims); @@ -221,10 +224,35 @@ public class OidcUserService implements OAuth2UserService accessibleScopes) { Assert.notNull(accessibleScopes, "accessibleScopes cannot be null"); this.accessibleScopes = accessibleScopes; } + /** + * Sets the {@code Predicate} used to determine if the UserInfo Endpoint should be + * called to retrieve information about the End-User (Resource Owner). + *

+ * By default, the UserInfo Endpoint is called if all of the following are true: + *

    + *
  • The user info endpoint is defined on the ClientRegistration
  • + *
  • The Client Registration uses the + * {@link AuthorizationGrantType#AUTHORIZATION_CODE}
  • + *
  • The access token contains one or more scopes allowed to access the UserInfo + * Endpoint ({@link OidcScopes#PROFILE profile}, {@link OidcScopes#EMAIL email}, + * {@link OidcScopes#ADDRESS address} or {@link OidcScopes#PHONE phone}) or the access + * token scopes are empty
  • + *
+ * @param retrieveUserInfo the function used to determine if the UserInfo Endpoint + * should be called + * @since 6.3 + */ + public final void setRetrieveUserInfo(Predicate retrieveUserInfo) { + Assert.notNull(retrieveUserInfo, "retrieveUserInfo cannot be null"); + this.retrieveUserInfo = retrieveUserInfo; + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java index 14acfdea16..bf20d712f6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.function.Function; +import java.util.function.Predicate; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -107,6 +108,15 @@ public class OidcReactiveOAuth2UserServiceTests { assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setClaimTypeConverterFactory(null)); } + @Test + public void setRetrieveUserInfoWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.userService.setRetrieveUserInfo(null)) + .withMessage("retrieveUserInfo cannot be null"); + // @formatter:on + } + @Test public void loadUserWhenUserInfoUriNullThenUserInfoNotRetrieved() { this.registration.userInfoUri(null); @@ -183,6 +193,48 @@ public class OidcReactiveOAuth2UserServiceTests { verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration())); } + @Test + public void loadUserWhenTokenScopesIsEmptyThenUserInfoNotRetrieved() { + // @formatter:off + OAuth2AccessToken accessToken = new OAuth2AccessToken( + this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + this.accessToken.getIssuedAt(), + this.accessToken.getExpiresAt(), + Collections.emptySet()); + // @formatter:on + OidcUserRequest userRequest = new OidcUserRequest(this.registration.build(), accessToken, this.idToken); + OidcUser oidcUser = this.userService.loadUser(userRequest).block(); + assertThat(oidcUser).isNotNull(); + assertThat(oidcUser.getUserInfo()).isNull(); + } + + @Test + public void loadUserWhenCustomRetrieveUserInfoSetThenUsed() { + Map attributes = new HashMap<>(); + attributes.put(StandardClaimNames.SUB, "subject"); + attributes.put("user", "steve"); + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes, + "user"); + given(this.oauth2UserService.loadUser(any())).willReturn(Mono.just(oauth2User)); + Predicate customRetrieveUserInfo = mock(Predicate.class); + this.userService.setRetrieveUserInfo(customRetrieveUserInfo); + given(customRetrieveUserInfo.test(any(OidcUserRequest.class))).willReturn(true); + // @formatter:off + OAuth2AccessToken accessToken = new OAuth2AccessToken( + this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + this.accessToken.getIssuedAt(), + this.accessToken.getExpiresAt(), + Collections.emptySet()); + // @formatter:on + OidcUserRequest userRequest = new OidcUserRequest(this.registration.build(), accessToken, this.idToken); + OidcUser oidcUser = this.userService.loadUser(userRequest).block(); + assertThat(oidcUser).isNotNull(); + assertThat(oidcUser.getUserInfo()).isNotNull(); + verify(customRetrieveUserInfo).test(userRequest); + } + @Test public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index 6d63e8a5c3..4b08664a80 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -23,6 +23,7 @@ import java.util.Iterator; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.Function; +import java.util.function.Predicate; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -58,6 +59,7 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.same; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -129,6 +131,15 @@ public class OidcUserServiceTests { this.userService.setAccessibleScopes(Collections.emptySet()); } + @Test + public void setRetrieveUserInfoWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.userService.setRetrieveUserInfo(null)) + .withMessage("retrieveUserInfo cannot be null"); + // @formatter:on + } + @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null)); @@ -218,6 +229,30 @@ public class OidcUserServiceTests { assertThat(user.getUserInfo()).isNotNull(); } + @Test + public void loadUserWhenCustomRetrieveUserInfoSetThenUsed() { + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(userInfoResponse)); + String userInfoUri = this.server.url("/user").toString(); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); + this.accessToken = TestOAuth2AccessTokens.noScopes(); + Predicate customRetrieveUserInfo = mock(Predicate.class); + given(customRetrieveUserInfo.test(any(OidcUserRequest.class))).willReturn(true); + this.userService.setRetrieveUserInfo(customRetrieveUserInfo); + OidcUser user = this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + assertThat(user.getUserInfo()).isNotNull(); + } + @Test public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { // @formatter:off