diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java index 21d2c7d012..7feb511682 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java @@ -73,7 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< private Predicate retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo; - private BiFunction> oidcUserMapper = this::getUser; + private Converter> oidcUserConverter = (source) -> Mono + .just(OidcUserRequestUtils.getUser(source)); /** * Returns the default {@link Converter}'s used for type conversion of claim values @@ -102,34 +103,26 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< public Mono loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException { Assert.notNull(userRequest, "userRequest cannot be null"); // @formatter:off - return getUserInfo(userRequest) - .flatMap((userInfo) -> this.oidcUserMapper.apply(userRequest, userInfo)) - .switchIfEmpty(Mono.defer(() -> this.oidcUserMapper.apply(userRequest, null))); + return Mono.just(userRequest) + .filter(this.retrieveUserInfo::test) + .flatMap(this.oauth2UserService::loadUser) + .flatMap((oauth2User) -> toOidcUser(userRequest, oauth2User)) + .switchIfEmpty(Mono.defer(() -> this.oidcUserConverter.convert(new OidcUserSource(userRequest)))); // @formatter:on } - private Mono getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) { - return Mono.just(OidcUserRequestUtils.getUser(userRequest, userInfo)); - } - - private Mono getUserInfo(OidcUserRequest userRequest) { - if (!this.retrieveUserInfo.test(userRequest)) { - return Mono.empty(); - } - // @formatter:off - return this.oauth2UserService - .loadUser(userRequest) - .map(OAuth2User::getAttributes) - .map((claims) -> convertClaims(claims, userRequest.getClientRegistration())) - .map(OidcUserInfo::new) - .doOnNext((userInfo) -> { - String subject = userInfo.getSubject(); - if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - }); - // @formatter:on + private Mono toOidcUser(OidcUserRequest userRequest, OAuth2User oauth2User) { + return Mono.defer(() -> { + Map claims = convertClaims(oauth2User.getAttributes(), userRequest.getClientRegistration()); + OidcUserInfo userInfo = new OidcUserInfo(claims); + String subject = userInfo.getSubject(); + if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + OidcUserSource source = new OidcUserSource(userRequest, userInfo, oauth2User); + return this.oidcUserConverter.convert(source); + }); } private Map convertClaims(Map claims, ClientRegistration clientRegistration) { @@ -229,10 +222,21 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< * @param oidcUserMapper the function used to map the {@link OidcUser} from the * {@link OidcUserRequest} and {@link OidcUserInfo} * @since 6.3 + * @deprecated Use {@link #setOidcUserConverter(Converter)} instead */ + @Deprecated(since = "7.0", forRemoval = true) public final void setOidcUserMapper(BiFunction> oidcUserMapper) { Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null"); - this.oidcUserMapper = oidcUserMapper; + this.oidcUserConverter = (source) -> oidcUserMapper.apply(source.getUserRequest(), source.getUserInfo()); + } + + /** + * Allows converting from the {@link OidcUserSource} to and {@link OidcUser}. + * @param oidcUserConverter the {@link Converter} to use. Cannot be null. + */ + public void setOidcUserConverter(Converter> oidcUserConverter) { + Assert.notNull(oidcUserConverter, "oidcUserConverter cannot be null"); + this.oidcUserConverter = oidcUserConverter; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java index a9f3629aae..1172aee2d5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java @@ -76,7 +76,9 @@ final class OidcUserRequestUtils { return false; } - static OidcUser getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) { + static OidcUser getUser(OidcUserSource userMetadata) { + OidcUserRequest userRequest = userMetadata.getUserRequest(); + OidcUserInfo userInfo = userMetadata.getUserInfo(); Set authorities = new LinkedHashSet<>(); ClientRegistration.ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails(); String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index a7b0151ae0..162e94e0aa 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -82,7 +82,7 @@ public class OidcUserService implements OAuth2UserService retrieveUserInfo = this::shouldRetrieveUserInfo; - private BiFunction oidcUserMapper = OidcUserRequestUtils::getUser; + private Converter oidcUserConverter = OidcUserRequestUtils::getUser; /** * Returns the default {@link Converter}'s used for type conversion of claim values @@ -111,8 +111,9 @@ public class OidcUserService implements OAuth2UserService claims = getClaims(userRequest, oauth2User); userInfo = new OidcUserInfo(claims); // https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse @@ -133,7 +134,8 @@ public class OidcUserService implements OAuth2UserService getClaims(OidcUserRequest userRequest, OAuth2User oauth2User) { @@ -293,10 +295,21 @@ public class OidcUserService implements OAuth2UserService oidcUserMapper) { Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null"); - this.oidcUserMapper = oidcUserMapper; + this.oidcUserConverter = (source) -> oidcUserMapper.apply(source.getUserRequest(), source.getUserInfo()); + } + + /** + * Allows converting from the {@link OidcUserSource} to and {@link OidcUser}. + * @param oidcUserConverter the {@link Converter} to use. Cannot be null. + */ + public void setOidcUserConverter(Converter oidcUserConverter) { + Assert.notNull(oidcUserConverter, "oidcUserConverter cannot be null"); + this.oidcUserConverter = oidcUserConverter; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserSource.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserSource.java new file mode 100644 index 0000000000..98f7f56213 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserSource.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.userinfo; + +import org.jspecify.annotations.Nullable; + +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.util.Assert; + +/** + * The source for the converter to + * {@link org.springframework.security.oauth2.core.oidc.user.OidcUser}. + * + * @author Rob Winch + * @since 7.0 + */ +public class OidcUserSource { + + private final OidcUserRequest userRequest; + + private final @Nullable OidcUserInfo userInfo; + + private final @Nullable OAuth2User oauth2User; + + public OidcUserSource(OidcUserRequest userRequest) { + this(userRequest, null, null); + } + + public OidcUserSource(OidcUserRequest userRequest, @Nullable OidcUserInfo userInfo, + @Nullable OAuth2User oauth2User) { + Assert.notNull(userRequest, "userRequest cannot be null"); + this.userRequest = userRequest; + this.userInfo = userInfo; + this.oauth2User = oauth2User; + } + + public OidcUserRequest getUserRequest() { + return this.userRequest; + } + + public @Nullable OidcUserInfo getUserInfo() { + return this.userInfo; + } + + public @Nullable OAuth2User getOauth2User() { + return this.oauth2User; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java index 172a7b3bac..d9d7a4f7a3 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java @@ -316,6 +316,36 @@ public class OidcReactiveOAuth2UserServiceTests { assertThat(userAuthority.getUserNameAttributeName()).isEqualTo("id"); } + @Test + public void loadUserWhenCustomOidcUserConverterSetThenUsed() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .userInfoUri("https://example.com/user") + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName(StandardClaimNames.SUB) + .build(); + this.accessToken = TestOAuth2AccessTokens.scopes(clientRegistration.getScopes().toArray(new String[0])); + Converter> oidcUserConverter = mock(Converter.class); + String nameAttributeKey = IdTokenClaimNames.SUB; + OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken, + nameAttributeKey); + OAuth2User oauth2User = new DefaultOAuth2User(actualUser.getAuthorities(), actualUser.getClaims(), + nameAttributeKey); + ReactiveOAuth2UserService oauth2 = mock(ReactiveOAuth2UserService.class); + given(oauth2.loadUser(any())).willReturn(Mono.just(oauth2User)); + given(oidcUserConverter.convert(any())).willReturn(Mono.just(actualUser)); + this.userService.setOauth2UserService(oauth2); + this.userService.setOidcUserConverter(oidcUserConverter); + OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, this.accessToken, this.idToken); + OidcUser user = this.userService.loadUser(userRequest).block(); + assertThat(user).isEqualTo(actualUser); + ArgumentCaptor metadataCptr = ArgumentCaptor.forClass(OidcUserSource.class); + verify(oidcUserConverter).convert(metadataCptr.capture()); + OidcUserSource metadata = metadataCptr.getValue(); + assertThat(metadata.getUserRequest()).isEqualTo(userRequest); + assertThat(metadata.getOauth2User()).isEqualTo(oauth2User); + assertThat(metadata.getUserInfo()).isNotNull(); + } + @Test public void loadUserWhenNestedUserInfoSuccessThenReturnUser() throws IOException { // @formatter:off diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index baa574fa1e..318877d8be 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -44,6 +44,8 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; @@ -58,6 +60,7 @@ import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; @@ -155,6 +158,15 @@ public class OidcUserServiceTests { // @formatter:on } + @Test + public void setOidcUserConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.userService.setOidcUserConverter(null)) + .withMessage("oidcUserConverter cannot be null"); + // @formatter:on + } + @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null)); @@ -299,6 +311,33 @@ public class OidcUserServiceTests { assertThat(userInfo.getClaimAsString("preferred_username")).isEqualTo("user1"); } + @Test + public void loadUserWhenCustomOidcUserConverterSetThenUsed() { + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri("https://example.com/user") + .build(); + this.accessToken = TestOAuth2AccessTokens.noScopes(); + Converter oidcUserConverter = mock(Converter.class); + String nameAttributeKey = IdTokenClaimNames.SUB; + OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken, + nameAttributeKey); + OAuth2User oauth2User = new DefaultOAuth2User(actualUser.getAuthorities(), actualUser.getClaims(), + nameAttributeKey); + OAuth2UserService oauth2 = mock(OAuth2UserService.class); + given(oauth2.loadUser(any())).willReturn(oauth2User); + given(oidcUserConverter.convert(any())).willReturn(actualUser); + this.userService.setOauth2UserService(oauth2); + this.userService.setOidcUserConverter(oidcUserConverter); + OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, this.accessToken, this.idToken); + OidcUser user = this.userService.loadUser(userRequest); + assertThat(user).isEqualTo(actualUser); + ArgumentCaptor metadataCptr = ArgumentCaptor.forClass(OidcUserSource.class); + verify(oidcUserConverter).convert(metadataCptr.capture()); + OidcUserSource metadata = metadataCptr.getValue(); + assertThat(metadata.getUserRequest()).isEqualTo(userRequest); + assertThat(metadata.getOauth2User()).isEqualTo(oauth2User); + assertThat(metadata.getUserInfo()).isNotNull(); + } + @Test public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { // @formatter:off