From bf877a986410216bbf3f29118455f05fad2873a3 Mon Sep 17 00:00:00 2001 From: Rob Winch <362503+rwinch@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:14:47 -0500 Subject: [PATCH] Add OAuth2User to OidcUser Conversion Params Previously the Oidc(Reactive)OAuth2UserService APIs allowed a strategy for converting to the OidcUser with the OidcUserRequest and OidcUserInfo. The input should also include the OAuth2User to make it simple to use the OAuth2User as a part of the conversion. This commit introduces OidcUserSource as a POJO containing OidcUserRequest, OidcUserInfo, and OAuth2User. It then updates the OidcUser conversion strategy in OidcUserService and OidcReactiveOAuth2UserService to accept OidcUserSource as the source for the Converter used to create OidUser. Closes gh-17626 --- .../OidcReactiveOAuth2UserService.java | 58 +++++++++-------- .../oidc/userinfo/OidcUserRequestUtils.java | 4 +- .../client/oidc/userinfo/OidcUserService.java | 21 ++++-- .../client/oidc/userinfo/OidcUserSource.java | 64 +++++++++++++++++++ .../OidcReactiveOAuth2UserServiceTests.java | 30 +++++++++ .../oidc/userinfo/OidcUserServiceTests.java | 39 +++++++++++ 6 files changed, 184 insertions(+), 32 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserSource.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java index 21d2c7d012..7feb511682 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java @@ -73,7 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< private Predicate retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo; - private BiFunction> oidcUserMapper = this::getUser; + private Converter> oidcUserConverter = (source) -> Mono + .just(OidcUserRequestUtils.getUser(source)); /** * Returns the default {@link Converter}'s used for type conversion of claim values @@ -102,34 +103,26 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< public Mono loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException { Assert.notNull(userRequest, "userRequest cannot be null"); // @formatter:off - return getUserInfo(userRequest) - .flatMap((userInfo) -> this.oidcUserMapper.apply(userRequest, userInfo)) - .switchIfEmpty(Mono.defer(() -> this.oidcUserMapper.apply(userRequest, null))); + return Mono.just(userRequest) + .filter(this.retrieveUserInfo::test) + .flatMap(this.oauth2UserService::loadUser) + .flatMap((oauth2User) -> toOidcUser(userRequest, oauth2User)) + .switchIfEmpty(Mono.defer(() -> this.oidcUserConverter.convert(new OidcUserSource(userRequest)))); // @formatter:on } - private Mono getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) { - return Mono.just(OidcUserRequestUtils.getUser(userRequest, userInfo)); - } - - private Mono getUserInfo(OidcUserRequest userRequest) { - if (!this.retrieveUserInfo.test(userRequest)) { - return Mono.empty(); - } - // @formatter:off - return this.oauth2UserService - .loadUser(userRequest) - .map(OAuth2User::getAttributes) - .map((claims) -> convertClaims(claims, userRequest.getClientRegistration())) - .map(OidcUserInfo::new) - .doOnNext((userInfo) -> { - String subject = userInfo.getSubject(); - if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - }); - // @formatter:on + private Mono toOidcUser(OidcUserRequest userRequest, OAuth2User oauth2User) { + return Mono.defer(() -> { + Map claims = convertClaims(oauth2User.getAttributes(), userRequest.getClientRegistration()); + OidcUserInfo userInfo = new OidcUserInfo(claims); + String subject = userInfo.getSubject(); + if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + OidcUserSource source = new OidcUserSource(userRequest, userInfo, oauth2User); + return this.oidcUserConverter.convert(source); + }); } private Map convertClaims(Map claims, ClientRegistration clientRegistration) { @@ -229,10 +222,21 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< * @param oidcUserMapper the function used to map the {@link OidcUser} from the * {@link OidcUserRequest} and {@link OidcUserInfo} * @since 6.3 + * @deprecated Use {@link #setOidcUserConverter(Converter)} instead */ + @Deprecated(since = "7.0", forRemoval = true) public final void setOidcUserMapper(BiFunction> oidcUserMapper) { Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null"); - this.oidcUserMapper = oidcUserMapper; + this.oidcUserConverter = (source) -> oidcUserMapper.apply(source.getUserRequest(), source.getUserInfo()); + } + + /** + * Allows converting from the {@link OidcUserSource} to and {@link OidcUser}. + * @param oidcUserConverter the {@link Converter} to use. Cannot be null. + */ + public void setOidcUserConverter(Converter> oidcUserConverter) { + Assert.notNull(oidcUserConverter, "oidcUserConverter cannot be null"); + this.oidcUserConverter = oidcUserConverter; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java index a9f3629aae..1172aee2d5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java @@ -76,7 +76,9 @@ final class OidcUserRequestUtils { return false; } - static OidcUser getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) { + static OidcUser getUser(OidcUserSource userMetadata) { + OidcUserRequest userRequest = userMetadata.getUserRequest(); + OidcUserInfo userInfo = userMetadata.getUserInfo(); Set authorities = new LinkedHashSet<>(); ClientRegistration.ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails(); String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index a7b0151ae0..162e94e0aa 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -82,7 +82,7 @@ public class OidcUserService implements OAuth2UserService retrieveUserInfo = this::shouldRetrieveUserInfo; - private BiFunction oidcUserMapper = OidcUserRequestUtils::getUser; + private Converter oidcUserConverter = OidcUserRequestUtils::getUser; /** * Returns the default {@link Converter}'s used for type conversion of claim values @@ -111,8 +111,9 @@ public class OidcUserService implements OAuth2UserService claims = getClaims(userRequest, oauth2User); userInfo = new OidcUserInfo(claims); // https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse @@ -133,7 +134,8 @@ public class OidcUserService implements OAuth2UserService getClaims(OidcUserRequest userRequest, OAuth2User oauth2User) { @@ -293,10 +295,21 @@ public class OidcUserService implements OAuth2UserService oidcUserMapper) { Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null"); - this.oidcUserMapper = oidcUserMapper; + this.oidcUserConverter = (source) -> oidcUserMapper.apply(source.getUserRequest(), source.getUserInfo()); + } + + /** + * Allows converting from the {@link OidcUserSource} to and {@link OidcUser}. + * @param oidcUserConverter the {@link Converter} to use. Cannot be null. + */ + public void setOidcUserConverter(Converter oidcUserConverter) { + Assert.notNull(oidcUserConverter, "oidcUserConverter cannot be null"); + this.oidcUserConverter = oidcUserConverter; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserSource.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserSource.java new file mode 100644 index 0000000000..98f7f56213 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserSource.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.userinfo; + +import org.jspecify.annotations.Nullable; + +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.util.Assert; + +/** + * The source for the converter to + * {@link org.springframework.security.oauth2.core.oidc.user.OidcUser}. + * + * @author Rob Winch + * @since 7.0 + */ +public class OidcUserSource { + + private final OidcUserRequest userRequest; + + private final @Nullable OidcUserInfo userInfo; + + private final @Nullable OAuth2User oauth2User; + + public OidcUserSource(OidcUserRequest userRequest) { + this(userRequest, null, null); + } + + public OidcUserSource(OidcUserRequest userRequest, @Nullable OidcUserInfo userInfo, + @Nullable OAuth2User oauth2User) { + Assert.notNull(userRequest, "userRequest cannot be null"); + this.userRequest = userRequest; + this.userInfo = userInfo; + this.oauth2User = oauth2User; + } + + public OidcUserRequest getUserRequest() { + return this.userRequest; + } + + public @Nullable OidcUserInfo getUserInfo() { + return this.userInfo; + } + + public @Nullable OAuth2User getOauth2User() { + return this.oauth2User; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java index 172a7b3bac..d9d7a4f7a3 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java @@ -316,6 +316,36 @@ public class OidcReactiveOAuth2UserServiceTests { assertThat(userAuthority.getUserNameAttributeName()).isEqualTo("id"); } + @Test + public void loadUserWhenCustomOidcUserConverterSetThenUsed() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .userInfoUri("https://example.com/user") + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName(StandardClaimNames.SUB) + .build(); + this.accessToken = TestOAuth2AccessTokens.scopes(clientRegistration.getScopes().toArray(new String[0])); + Converter> oidcUserConverter = mock(Converter.class); + String nameAttributeKey = IdTokenClaimNames.SUB; + OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken, + nameAttributeKey); + OAuth2User oauth2User = new DefaultOAuth2User(actualUser.getAuthorities(), actualUser.getClaims(), + nameAttributeKey); + ReactiveOAuth2UserService oauth2 = mock(ReactiveOAuth2UserService.class); + given(oauth2.loadUser(any())).willReturn(Mono.just(oauth2User)); + given(oidcUserConverter.convert(any())).willReturn(Mono.just(actualUser)); + this.userService.setOauth2UserService(oauth2); + this.userService.setOidcUserConverter(oidcUserConverter); + OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, this.accessToken, this.idToken); + OidcUser user = this.userService.loadUser(userRequest).block(); + assertThat(user).isEqualTo(actualUser); + ArgumentCaptor metadataCptr = ArgumentCaptor.forClass(OidcUserSource.class); + verify(oidcUserConverter).convert(metadataCptr.capture()); + OidcUserSource metadata = metadataCptr.getValue(); + assertThat(metadata.getUserRequest()).isEqualTo(userRequest); + assertThat(metadata.getOauth2User()).isEqualTo(oauth2User); + assertThat(metadata.getUserInfo()).isNotNull(); + } + @Test public void loadUserWhenNestedUserInfoSuccessThenReturnUser() throws IOException { // @formatter:off diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index baa574fa1e..318877d8be 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -44,6 +44,8 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; @@ -58,6 +60,7 @@ import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; @@ -155,6 +158,15 @@ public class OidcUserServiceTests { // @formatter:on } + @Test + public void setOidcUserConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.userService.setOidcUserConverter(null)) + .withMessage("oidcUserConverter cannot be null"); + // @formatter:on + } + @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null)); @@ -299,6 +311,33 @@ public class OidcUserServiceTests { assertThat(userInfo.getClaimAsString("preferred_username")).isEqualTo("user1"); } + @Test + public void loadUserWhenCustomOidcUserConverterSetThenUsed() { + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri("https://example.com/user") + .build(); + this.accessToken = TestOAuth2AccessTokens.noScopes(); + Converter oidcUserConverter = mock(Converter.class); + String nameAttributeKey = IdTokenClaimNames.SUB; + OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken, + nameAttributeKey); + OAuth2User oauth2User = new DefaultOAuth2User(actualUser.getAuthorities(), actualUser.getClaims(), + nameAttributeKey); + OAuth2UserService oauth2 = mock(OAuth2UserService.class); + given(oauth2.loadUser(any())).willReturn(oauth2User); + given(oidcUserConverter.convert(any())).willReturn(actualUser); + this.userService.setOauth2UserService(oauth2); + this.userService.setOidcUserConverter(oidcUserConverter); + OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, this.accessToken, this.idToken); + OidcUser user = this.userService.loadUser(userRequest); + assertThat(user).isEqualTo(actualUser); + ArgumentCaptor metadataCptr = ArgumentCaptor.forClass(OidcUserSource.class); + verify(oidcUserConverter).convert(metadataCptr.capture()); + OidcUserSource metadata = metadataCptr.getValue(); + assertThat(metadata.getUserRequest()).isEqualTo(userRequest); + assertThat(metadata.getOauth2User()).isEqualTo(oauth2User); + assertThat(metadata.getUserInfo()).isNotNull(); + } + @Test public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { // @formatter:off