From 8e8e6d1b17364f9decfd6e770a01eec19a9e7577 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Tue, 21 Sep 2021 14:31:32 -0500 Subject: [PATCH] Implement User Info Endpoint Closes gh-176 --- .../server/authorization/OidcConfigurer.java | 19 +- .../OidcUserInfoEndpointConfigurer.java | 112 +++++++ .../OidcUserInfoHttpMessageConverter.java | 105 +++--- .../config/ConfigurationSettingNames.java | 5 + .../config/ProviderSettings.java | 22 +- .../oidc/DefaultUserInfoClaimsMapper.java | 11 - .../oidc/UserInfoClaimsMapper.java | 9 - .../OidcUserInfoAuthenticationProvider.java | 199 +++++++++++ .../OidcUserInfoAuthenticationToken.java | 88 +++++ .../oidc/web/OidcUserInfoEndpointFilter.java | 157 +++++---- .../authorization/OidcUserInfoTests.java | 308 ++++++++++++++++++ ...OidcUserInfoHttpMessageConverterTests.java | 224 +++++++++++++ .../config/ProviderSettingsTests.java | 13 +- ...dcUserInfoAuthenticationProviderTests.java | 285 ++++++++++++++++ .../OidcUserInfoAuthenticationTokenTests.java | 60 ++++ .../web/OidcUserInfoEndpointFilterTests.java | 254 +++++++++++++++ 16 files changed, 1720 insertions(+), 151 deletions(-) create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcUserInfoEndpointConfigurer.java delete mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/DefaultUserInfoClaimsMapper.java delete mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/UserInfoClaimsMapper.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProvider.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationToken.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcUserInfoTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcUserInfoHttpMessageConverterTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProviderTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationTokenTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcConfigurer.java index 71bbd323..dd17955a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcConfigurer.java @@ -36,9 +36,11 @@ import org.springframework.security.web.util.matcher.RequestMatcher; * @since 0.2.0 * @see OAuth2AuthorizationServerConfigurer#oidc * @see OidcClientRegistrationEndpointConfigurer + * @see OidcUserInfoEndpointConfigurer * @see OidcProviderConfigurationEndpointFilter */ public final class OidcConfigurer extends AbstractOAuth2Configurer { + private final OidcUserInfoEndpointConfigurer userInfoEndpointConfigurer; private OidcClientRegistrationEndpointConfigurer clientRegistrationEndpointConfigurer; private RequestMatcher requestMatcher; @@ -47,6 +49,7 @@ public final class OidcConfigurer extends AbstractOAuth2Configurer { */ OidcConfigurer(ObjectPostProcessor objectPostProcessor) { super(objectPostProcessor); + this.userInfoEndpointConfigurer = new OidcUserInfoEndpointConfigurer(objectPostProcessor); } /** @@ -63,8 +66,20 @@ public final class OidcConfigurer extends AbstractOAuth2Configurer { return this; } + /** + * Configures the OpenID Connect 1.0 UserInfo Endpoint. + * + * @param userInfoEndpointCustomizer the {@link Customizer} providing access to the {@link OidcUserInfoEndpointConfigurer} + * @return the {@link OidcConfigurer} for further configuration + */ + public OidcConfigurer userInfoEndpoint(Customizer userInfoEndpointCustomizer) { + userInfoEndpointCustomizer.customize(this.userInfoEndpointConfigurer); + return this; + } + @Override > void init(B builder) { + this.userInfoEndpointConfigurer.init(builder); if (this.clientRegistrationEndpointConfigurer != null) { this.clientRegistrationEndpointConfigurer.init(builder); } @@ -75,14 +90,16 @@ public final class OidcConfigurer extends AbstractOAuth2Configurer { requestMatchers.add(new AntPathRequestMatcher( "/.well-known/openid-configuration", HttpMethod.GET.name())); } + requestMatchers.add(this.userInfoEndpointConfigurer.getRequestMatcher()); if (this.clientRegistrationEndpointConfigurer != null) { requestMatchers.add(this.clientRegistrationEndpointConfigurer.getRequestMatcher()); } - this.requestMatcher = !requestMatchers.isEmpty() ? new OrRequestMatcher(requestMatchers) : request -> false; + this.requestMatcher = requestMatchers.size() > 1 ? new OrRequestMatcher(requestMatchers) : requestMatchers.get(0); } @Override > void configure(B builder) { + this.userInfoEndpointConfigurer.configure(builder); if (this.clientRegistrationEndpointConfigurer != null) { this.clientRegistrationEndpointConfigurer.configure(builder); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcUserInfoEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcUserInfoEndpointConfigurer.java new file mode 100644 index 00000000..0f4f5c56 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcUserInfoEndpointConfigurer.java @@ -0,0 +1,112 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; + +import java.util.function.Function; + +import org.springframework.http.HttpMethod; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.config.annotation.ObjectPostProcessor; +import org.springframework.security.config.annotation.web.HttpSecurityBuilder; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.authentication.OAuth2AuthenticationContext; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.oidc.web.OidcUserInfoEndpointFilter; +import org.springframework.security.web.access.intercept.FilterSecurityInterceptor; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.OrRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; + +/** + * Configurer for OpenID Connect 1.0 UserInfo Endpoint. + * + * @author Steve Riesenberg + * @since 0.2.1 + * @see OidcConfigurer#userInfoEndpoint + * @see OidcUserInfoEndpointFilter + */ +public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configurer { + private RequestMatcher requestMatcher; + private Function userInfoMapper; + + /** + * Restrict for internal use only. + */ + OidcUserInfoEndpointConfigurer(ObjectPostProcessor objectPostProcessor) { + super(objectPostProcessor); + } + + /** + * Sets the {@link Function} used to extract claims from an {@link OAuth2AuthenticationContext} + * to an instance of {@link OidcUserInfo}. + * + *

+ * The {@link OAuth2AuthenticationContext} gives the mapper access to the {@link OidcUserInfoAuthenticationToken}. + * In addition, the following context attributes are supported: + *

    + *
  • {@code OAuth2Token.class} - The {@link OAuth2Token} containing the bearer token used to make the request.
  • + *
  • {@code OAuth2Authorization.class} - The {@link OAuth2Authorization} containing the {@link OidcIdToken} and + * {@link OAuth2AccessToken} associated with the bearer token used to make the request.
  • + *
+ * + * @param userInfoMapper the {@link Function} used to extract claims from an {@link OAuth2AuthenticationContext} to an instance of {@link OidcUserInfo} + * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration + */ + public OidcUserInfoEndpointConfigurer userInfoMapper(Function userInfoMapper) { + this.userInfoMapper = userInfoMapper; + return this; + } + + @Override + > void init(B builder) { + ProviderSettings providerSettings = OAuth2ConfigurerUtils.getProviderSettings(builder); + String userInfoEndpointUri = providerSettings.getOidcUserInfoEndpoint(); + this.requestMatcher = new OrRequestMatcher( + new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.GET.name()), + new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.POST.name())); + + OidcUserInfoAuthenticationProvider oidcUserInfoAuthenticationProvider = + new OidcUserInfoAuthenticationProvider( + OAuth2ConfigurerUtils.getAuthorizationService(builder)); + if (this.userInfoMapper != null) { + oidcUserInfoAuthenticationProvider.setUserInfoMapper(this.userInfoMapper); + } + builder.authenticationProvider(postProcess(oidcUserInfoAuthenticationProvider)); + } + + @Override + > void configure(B builder) { + AuthenticationManager authenticationManager = builder.getSharedObject(AuthenticationManager.class); + ProviderSettings providerSettings = OAuth2ConfigurerUtils.getProviderSettings(builder); + + OidcUserInfoEndpointFilter oidcUserInfoEndpointFilter = + new OidcUserInfoEndpointFilter( + authenticationManager, + providerSettings.getOidcUserInfoEndpoint()); + builder.addFilterAfter(postProcess(oidcUserInfoEndpointFilter), FilterSecurityInterceptor.class); + } + + @Override + RequestMatcher getRequestMatcher() { + return this.requestMatcher; + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcUserInfoHttpMessageConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcUserInfoHttpMessageConverter.java index f74f5170..361787b9 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcUserInfoHttpMessageConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcUserInfoHttpMessageConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,44 +15,46 @@ */ package org.springframework.security.oauth2.core.oidc.http.converter; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; + import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpOutputMessage; import org.springframework.http.MediaType; -import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.AbstractHttpMessageConverter; import org.springframework.http.converter.GenericHttpMessageConverter; -import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageNotReadableException; -import org.springframework.http.converter.AbstractHttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotWritableException; import org.springframework.security.oauth2.core.converter.ClaimConversionService; import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.util.Assert; -import java.util.HashMap; -import java.util.Map; - /** - * A {@link HttpMessageConverter} for an {@link OidcUserInfo OIDC User Info Response}. + * A {@link HttpMessageConverter} for an {@link OidcUserInfo OpenID Connect UserInfo Request and Response}. * * @author Ido Salomon + * @author Steve Riesenberg + * @since 0.2.1 * @see AbstractHttpMessageConverter * @see OidcUserInfo - * @since 0.1.1 */ public class OidcUserInfoHttpMessageConverter extends AbstractHttpMessageConverter { private static final ParameterizedTypeReference> STRING_OBJECT_MAP = - new ParameterizedTypeReference>() { - }; + new ParameterizedTypeReference>() {}; - private final GenericHttpMessageConverter jsonMessageConverter = HttpMessageConverters.getJsonMessageConverter(); + private final GenericHttpMessageConverter jsonMessageConverter = + HttpMessageConverters.getJsonMessageConverter(); - private Converter, OidcUserInfo> oidcUserInfoConverter = new OidcUserInfoConverter(); - private Converter> oidcUserInfoParametersConverter = OidcUserInfo::getClaims; + private Converter, OidcUserInfo> userInfoConverter = new MapOidcUserInfoConverter(); + private Converter> userInfoParametersConverter = OidcUserInfo::getClaims; public OidcUserInfoHttpMessageConverter() { super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json")); @@ -68,12 +70,12 @@ public class OidcUserInfoHttpMessageConverter extends AbstractHttpMessageConvert protected OidcUserInfo readInternal(Class clazz, HttpInputMessage inputMessage) throws HttpMessageNotReadableException { try { - Map oidcUserInfoParameters = + Map userInfoParameters = (Map) this.jsonMessageConverter.read(STRING_OBJECT_MAP.getType(), null, inputMessage); - return this.oidcUserInfoConverter.convert(oidcUserInfoParameters); + return this.userInfoConverter.convert(userInfoParameters); } catch (Exception ex) { throw new HttpMessageNotReadableException( - "An error occurred reading the OIDC User Info: " + ex.getMessage(), ex, inputMessage); + "An error occurred reading the UserInfo: " + ex.getMessage(), ex, inputMessage); } } @@ -81,77 +83,82 @@ public class OidcUserInfoHttpMessageConverter extends AbstractHttpMessageConvert protected void writeInternal(OidcUserInfo oidcUserInfo, HttpOutputMessage outputMessage) throws HttpMessageNotWritableException { try { - Map oidcUserInfoResponseParameters = - this.oidcUserInfoParametersConverter.convert(oidcUserInfo); + Map userInfoResponseParameters = + this.userInfoParametersConverter.convert(oidcUserInfo); this.jsonMessageConverter.write( - oidcUserInfoResponseParameters, + userInfoResponseParameters, STRING_OBJECT_MAP.getType(), MediaType.APPLICATION_JSON, outputMessage ); } catch (Exception ex) { throw new HttpMessageNotWritableException( - "An error occurred writing the OIDC User Info response: " + ex.getMessage(), ex); + "An error occurred writing the UserInfo response: " + ex.getMessage(), ex); } } /** - * Sets the {@link Converter} used for converting the OIDC User Info parameters + * Sets the {@link Converter} used for converting the UserInfo parameters * to an {@link OidcUserInfo}. * - * @param oidcUserInfoConverter the {@link Converter} used for converting to an - * {@link OidcUserInfo} + * @param userInfoConverter the {@link Converter} used for converting to an + * {@link OidcUserInfo} */ - public final void setOidcUserInfoConverter(Converter, OidcUserInfo> oidcUserInfoConverter) { - Assert.notNull(oidcUserInfoConverter, "oidcUserInfoConverter cannot be null"); - this.oidcUserInfoConverter = oidcUserInfoConverter; + public final void setUserInfoConverter(Converter, OidcUserInfo> userInfoConverter) { + Assert.notNull(userInfoConverter, "userInfoConverter cannot be null"); + this.userInfoConverter = userInfoConverter; } /** * Sets the {@link Converter} used for converting the {@link OidcUserInfo} to a - * {@code Map} representation of the OIDC User Info. + * {@code Map} representation of the UserInfo. * - * @param oidcUserInfoParametersConverter the {@link Converter} used for converting to a - * {@code Map} representation of the OIDC User Info + * @param userInfoParametersConverter the {@link Converter} used for converting to a + * {@code Map} representation of the UserInfo */ - public final void setOidcUserInfoParametersConverter( - Converter> oidcUserInfoParametersConverter) { - Assert.notNull(oidcUserInfoParametersConverter, "oidcUserInfoParametersConverter cannot be null"); - this.oidcUserInfoParametersConverter = oidcUserInfoParametersConverter; + public final void setUserInfoParametersConverter( + Converter> userInfoParametersConverter) { + Assert.notNull(userInfoParametersConverter, "userInfoParametersConverter cannot be null"); + this.userInfoParametersConverter = userInfoParametersConverter; } - private static final class OidcUserInfoConverter implements Converter, OidcUserInfo> { + private static final class MapOidcUserInfoConverter implements Converter, OidcUserInfo> { + private static final ClaimConversionService CLAIM_CONVERSION_SERVICE = ClaimConversionService.getSharedInstance(); private static final TypeDescriptor OBJECT_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Object.class); private static final TypeDescriptor BOOLEAN_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Boolean.class); private static final TypeDescriptor STRING_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(String.class); + private static final TypeDescriptor INSTANT_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Instant.class); + private static final TypeDescriptor STRING_OBJECT_MAP_DESCRIPTOR = TypeDescriptor.map(Map.class, STRING_TYPE_DESCRIPTOR, OBJECT_TYPE_DESCRIPTOR); private final ClaimTypeConverter claimTypeConverter; - private OidcUserInfoConverter() { - Converter stringConverter = getConverter(STRING_TYPE_DESCRIPTOR); + private MapOidcUserInfoConverter() { Converter booleanConverter = getConverter(BOOLEAN_TYPE_DESCRIPTOR); + Converter stringConverter = getConverter(STRING_TYPE_DESCRIPTOR); + Converter instantConverter = getConverter(INSTANT_TYPE_DESCRIPTOR); + Converter mapConverter = getConverter(STRING_OBJECT_MAP_DESCRIPTOR); Map> claimConverters = new HashMap<>(); claimConverters.put(StandardClaimNames.SUB, stringConverter); - claimConverters.put(StandardClaimNames.PROFILE, stringConverter); - claimConverters.put(StandardClaimNames.ADDRESS, stringConverter); - claimConverters.put(StandardClaimNames.BIRTHDATE, stringConverter); - claimConverters.put(StandardClaimNames.EMAIL, stringConverter); - claimConverters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); claimConverters.put(StandardClaimNames.NAME, stringConverter); claimConverters.put(StandardClaimNames.GIVEN_NAME, stringConverter); - claimConverters.put(StandardClaimNames.MIDDLE_NAME, stringConverter); claimConverters.put(StandardClaimNames.FAMILY_NAME, stringConverter); + claimConverters.put(StandardClaimNames.MIDDLE_NAME, stringConverter); claimConverters.put(StandardClaimNames.NICKNAME, stringConverter); claimConverters.put(StandardClaimNames.PREFERRED_USERNAME, stringConverter); - claimConverters.put(StandardClaimNames.LOCALE, stringConverter); - claimConverters.put(StandardClaimNames.GENDER, stringConverter); - claimConverters.put(StandardClaimNames.PHONE_NUMBER, stringConverter); - claimConverters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, stringConverter); + claimConverters.put(StandardClaimNames.PROFILE, stringConverter); claimConverters.put(StandardClaimNames.PICTURE, stringConverter); - claimConverters.put(StandardClaimNames.ZONEINFO, stringConverter); claimConverters.put(StandardClaimNames.WEBSITE, stringConverter); - claimConverters.put(StandardClaimNames.UPDATED_AT, stringConverter); + claimConverters.put(StandardClaimNames.EMAIL, stringConverter); + claimConverters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); + claimConverters.put(StandardClaimNames.GENDER, stringConverter); + claimConverters.put(StandardClaimNames.BIRTHDATE, stringConverter); + claimConverters.put(StandardClaimNames.ZONEINFO, stringConverter); + claimConverters.put(StandardClaimNames.LOCALE, stringConverter); + claimConverters.put(StandardClaimNames.PHONE_NUMBER, stringConverter); + claimConverters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); + claimConverters.put(StandardClaimNames.ADDRESS, mapConverter); + claimConverters.put(StandardClaimNames.UPDATED_AT, instantConverter); this.claimTypeConverter = new ClaimTypeConverter(claimConverters); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ConfigurationSettingNames.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ConfigurationSettingNames.java index baa819be..4b0d082a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ConfigurationSettingNames.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ConfigurationSettingNames.java @@ -94,6 +94,11 @@ public final class ConfigurationSettingNames { */ public static final String OIDC_CLIENT_REGISTRATION_ENDPOINT = PROVIDER_SETTINGS_NAMESPACE.concat("oidc-client-registration-endpoint"); + /** + * Set the Provider's OpenID Connect 1.0 UserInfo endpoint. + */ + public static final String OIDC_USER_INFO_ENDPOINT = PROVIDER_SETTINGS_NAMESPACE.concat("oidc-user-info-endpoint"); + private Provider() { } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettings.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettings.java index d2ff2e9f..1d75dc8f 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettings.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettings.java @@ -97,6 +97,15 @@ public final class ProviderSettings extends AbstractSettings { return getSetting(ConfigurationSettingNames.Provider.OIDC_CLIENT_REGISTRATION_ENDPOINT); } + /** + * Returns the Provider's OpenID Connect 1.0 UserInfo endpoint. The default is {@code /userinfo}. + * + * @return the OpenID Connect 1.0 User Info endpoint + */ + public String getOidcUserInfoEndpoint() { + return getSetting(ConfigurationSettingNames.Provider.OIDC_USER_INFO_ENDPOINT); + } + /** * Constructs a new {@link Builder} with the default settings. * @@ -109,7 +118,8 @@ public final class ProviderSettings extends AbstractSettings { .jwkSetEndpoint("/oauth2/jwks") .tokenRevocationEndpoint("/oauth2/revoke") .tokenIntrospectionEndpoint("/oauth2/introspect") - .oidcClientRegistrationEndpoint("/connect/register"); + .oidcClientRegistrationEndpoint("/connect/register") + .oidcUserInfoEndpoint("/userinfo"); } /** @@ -202,6 +212,16 @@ public final class ProviderSettings extends AbstractSettings { return setting(ConfigurationSettingNames.Provider.OIDC_CLIENT_REGISTRATION_ENDPOINT, oidcClientRegistrationEndpoint); } + /** + * Sets the Provider's OpenID Connect 1.0 UserInfo endpoint. + * + * @param oidcUserInfoEndpoint the OpenID Connect 1.0 User Info endpoint + * @return the {@link Builder} for further configuration + */ + public Builder oidcUserInfoEndpoint(String oidcUserInfoEndpoint) { + return setting(ConfigurationSettingNames.Provider.OIDC_USER_INFO_ENDPOINT, oidcUserInfoEndpoint); + } + /** * Builds the {@link ProviderSettings}. * diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/DefaultUserInfoClaimsMapper.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/DefaultUserInfoClaimsMapper.java deleted file mode 100644 index 4d3d4a2c..00000000 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/DefaultUserInfoClaimsMapper.java +++ /dev/null @@ -1,11 +0,0 @@ -package org.springframework.security.oauth2.server.authorization.oidc; - -import org.springframework.security.oauth2.core.oidc.OidcUserInfo; - -public class DefaultUserInfoClaimsMapper implements UserInfoClaimsMapper { - - public OidcUserInfo map(Object principal) { - return null; // TODO - } - -} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/UserInfoClaimsMapper.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/UserInfoClaimsMapper.java deleted file mode 100644 index 48590a2d..00000000 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/UserInfoClaimsMapper.java +++ /dev/null @@ -1,9 +0,0 @@ -package org.springframework.security.oauth2.server.authorization.oidc; - -import org.springframework.security.oauth2.core.oidc.OidcUserInfo; - -public interface UserInfoClaimsMapper { - - OidcUserInfo map(Object principal); - -} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProvider.java new file mode 100644 index 00000000..e5b56202 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProvider.java @@ -0,0 +1,199 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.authentication.OAuth2AuthenticationContext; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; +import org.springframework.util.Assert; + +/** + * An {@link AuthenticationProvider} implementation for OpenID Connect 1.0 UserInfo Endpoint. + * + * @author Steve Riesenberg + * @since 0.2.1 + * @see OAuth2AuthorizationService + * @see 5.3. UserInfo Endpoint + */ +public final class OidcUserInfoAuthenticationProvider implements AuthenticationProvider { + + private final OAuth2AuthorizationService authorizationService; + + private Function userInfoMapper = new DefaultOidcUserInfoMapper(); + + /** + * Constructs an {@code OidcUserInfoAuthenticationProvider} using the provided parameters. + * + * @param authorizationService the authorization service + */ + public OidcUserInfoAuthenticationProvider(OAuth2AuthorizationService authorizationService) { + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.authorizationService = authorizationService; + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OidcUserInfoAuthenticationToken userInfoAuthentication = + (OidcUserInfoAuthenticationToken) authentication; + + AbstractOAuth2TokenAuthenticationToken accessTokenAuthentication = null; + if (AbstractOAuth2TokenAuthenticationToken.class.isAssignableFrom(userInfoAuthentication.getPrincipal().getClass())) { + accessTokenAuthentication = (AbstractOAuth2TokenAuthenticationToken) userInfoAuthentication.getPrincipal(); + } + if (accessTokenAuthentication == null || !accessTokenAuthentication.isAuthenticated()) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + + String accessTokenValue = accessTokenAuthentication.getToken().getTokenValue(); + + OAuth2Authorization authorization = this.authorizationService.findByToken( + accessTokenValue, OAuth2TokenType.ACCESS_TOKEN); + if (authorization == null) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + + OAuth2Authorization.Token authorizedAccessToken = authorization.getAccessToken(); + if (!authorizedAccessToken.isActive()) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + + if (!authorizedAccessToken.getToken().getScopes().contains(OidcScopes.OPENID)) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + } + + OAuth2Authorization.Token idToken = authorization.getToken(OidcIdToken.class); + if (idToken == null) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + + Map context = new HashMap<>(); + context.put(OAuth2Token.class, accessTokenAuthentication.getToken()); + context.put(OAuth2Authorization.class, authorization); + OAuth2AuthenticationContext authenticationContext = new OAuth2AuthenticationContext( + userInfoAuthentication, context); + + OidcUserInfo userInfo = this.userInfoMapper.apply(authenticationContext); + return new OidcUserInfoAuthenticationToken(accessTokenAuthentication, userInfo); + } + + @Override + public boolean supports(Class authentication) { + return OidcUserInfoAuthenticationToken.class.isAssignableFrom(authentication); + } + + /** + * Sets the {@link Function} used when mapping from an {@link OAuth2AuthenticationContext} + * to an instance of {@link OidcUserInfo} for the UserInfo response. + * + *

+ * The {@link OAuth2AuthenticationContext} gives the mapper access to the {@link OidcUserInfoAuthenticationToken}. + * In addition, the following context attributes are supported: + *

    + *
  • {@code OAuth2Token.class} - The {@link OAuth2Token} containing the bearer token used to make the request.
  • + *
  • {@code OAuth2Authorization.class} - The {@link OAuth2Authorization} containing the {@link OidcIdToken} and + * {@link OAuth2AccessToken} associated with the bearer token used to make the request.
  • + *
+ * + * @param userInfoMapper the {@link Function} used when mapping from an {@link OAuth2AuthenticationContext} + */ + public void setUserInfoMapper(Function userInfoMapper) { + Assert.notNull(userInfoMapper, "userInfoMapper cannot be null"); + this.userInfoMapper = userInfoMapper; + } + + private static final class DefaultOidcUserInfoMapper implements Function { + + private static final List EMAIL_CLAIMS = Arrays.asList( + StandardClaimNames.EMAIL, + StandardClaimNames.EMAIL_VERIFIED + ); + private static final List PHONE_CLAIMS = Arrays.asList( + StandardClaimNames.PHONE_NUMBER, + StandardClaimNames.PHONE_NUMBER_VERIFIED + ); + private static final List PROFILE_CLAIMS = Arrays.asList( + StandardClaimNames.NAME, + StandardClaimNames.FAMILY_NAME, + StandardClaimNames.GIVEN_NAME, + StandardClaimNames.MIDDLE_NAME, + StandardClaimNames.NICKNAME, + StandardClaimNames.PREFERRED_USERNAME, + StandardClaimNames.PROFILE, + StandardClaimNames.PICTURE, + StandardClaimNames.WEBSITE, + StandardClaimNames.GENDER, + StandardClaimNames.BIRTHDATE, + StandardClaimNames.ZONEINFO, + StandardClaimNames.LOCALE, + StandardClaimNames.UPDATED_AT + ); + + @Override + public OidcUserInfo apply(OAuth2AuthenticationContext authenticationContext) { + OAuth2Authorization authorization = authenticationContext.get(OAuth2Authorization.class); + OidcIdToken idToken = authorization.getToken(OidcIdToken.class).getToken(); + OAuth2AccessToken accessToken = authorization.getAccessToken().getToken(); + Map scopeRequestedClaims = getClaimsRequestedByScope(idToken.getClaims(), + accessToken.getScopes()); + + return new OidcUserInfo(scopeRequestedClaims); + } + + private Map getClaimsRequestedByScope(Map claims, Set requestedScopes) { + Set scopeRequestedClaimNames = new HashSet<>(32); + scopeRequestedClaimNames.add(StandardClaimNames.SUB); + + if (requestedScopes.contains(OidcScopes.ADDRESS)) { + scopeRequestedClaimNames.add(StandardClaimNames.ADDRESS); + } + if (requestedScopes.contains(OidcScopes.EMAIL)) { + scopeRequestedClaimNames.addAll(EMAIL_CLAIMS); + } + if (requestedScopes.contains(OidcScopes.PHONE)) { + scopeRequestedClaimNames.addAll(PHONE_CLAIMS); + } + if (requestedScopes.contains(OidcScopes.PROFILE)) { + scopeRequestedClaimNames.addAll(PROFILE_CLAIMS); + } + + Map requestedClaims = new HashMap<>(claims); + requestedClaims.keySet().removeIf(claimName -> !scopeRequestedClaimNames.contains(claimName)); + + return requestedClaims; + } + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationToken.java new file mode 100644 index 00000000..c74da8d2 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationToken.java @@ -0,0 +1,88 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.util.Collections; + +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.Version; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.util.Assert; + +/** + * An {@link Authentication} implementation used for OpenID Connect 1.0 UserInfo Endpoint. + * + * @author Steve Riesenberg + * @since 0.2.1 + * @see AbstractAuthenticationToken + * @see OidcUserInfo + * @see OidcUserInfoAuthenticationProvider + */ +public class OidcUserInfoAuthenticationToken extends AbstractAuthenticationToken { + + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + + private final Authentication principal; + private final OidcUserInfo userInfo; + + /** + * Constructs an {@code OidcUserInfoAuthenticationToken} using the provided parameters. + * + * @param principal the authenticated principal + */ + public OidcUserInfoAuthenticationToken(Authentication principal) { + super(Collections.emptyList()); + Assert.notNull(principal, "principal cannot be null"); + this.principal = principal; + this.userInfo = null; + setAuthenticated(false); + } + + /** + * Constructs an {@code OidcUserInfoAuthenticationToken} using the provided parameters. + * + * @param principal the authenticated principal + * @param userInfo the UserInfo claims + */ + public OidcUserInfoAuthenticationToken(Authentication principal, OidcUserInfo userInfo) { + super(Collections.emptyList()); + Assert.notNull(principal, "principal cannot be null"); + Assert.notNull(userInfo, "userInfo cannot be null"); + this.principal = principal; + this.userInfo = userInfo; + setAuthenticated(principal.isAuthenticated()); + } + + @Override + public Object getPrincipal() { + return this.principal; + } + + @Override + public Object getCredentials() { + return ""; + } + + /** + * Returns the UserInfo claims. + * + * @return the UserInfo claims + */ + public OidcUserInfo getUserInfo() { + return this.userInfo; + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java index fb39e2cb..42ad6495 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,6 @@ package org.springframework.security.oauth2.server.authorization.oidc.web; import java.io.IOException; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -28,116 +23,120 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; -import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.security.oauth2.core.oidc.http.converter.OidcUserInfoHttpMessageConverter; -import org.springframework.security.oauth2.server.authorization.oidc.UserInfoClaimsMapper; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; /** - * A {@code Filter} that processes OpenID User Info requests. + * A {@code Filter} that processes OpenID Connect 1.0 UserInfo Requests. * * @author Ido Salomon + * @author Steve Riesenberg + * @since 0.2.1 * @see OidcUserInfo - * @see 5.3.1. UserInfo Request - * @since 0.1.1 + * @see 5.3. UserInfo Endpoint */ -public class OidcUserInfoEndpointFilter extends OncePerRequestFilter { +public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter { /** - * The default endpoint {@code URI} for OpenID User Info requests. + * The default endpoint {@code URI} for OpenID Connect 1.0 UserInfo Requests. */ - public static final String DEFAULT_OIDC_USER_INFO_ENDPOINT_URI = "/userinfo"; + private static final String DEFAULT_OIDC_USER_INFO_ENDPOINT_URI = "/userinfo"; - private final RequestMatcher requestMatcher; - private final OidcUserInfoHttpMessageConverter oidcUserInfoHttpMessageConverter = + private final AuthenticationManager authenticationManager; + private final RequestMatcher userInfoEndpointMatcher; + + private final HttpMessageConverter userInfoHttpMessageConverter = new OidcUserInfoHttpMessageConverter(); - private final UserInfoClaimsMapper userInfoClaimsMapper; - - public OidcUserInfoEndpointFilter(UserInfoClaimsMapper userInfoClaimsMapper) { - AntPathRequestMatcher userInfoGetMatcher = new AntPathRequestMatcher( - DEFAULT_OIDC_USER_INFO_ENDPOINT_URI, - HttpMethod.GET.name() - ); - AntPathRequestMatcher userInfoPostMatcher = new AntPathRequestMatcher( - DEFAULT_OIDC_USER_INFO_ENDPOINT_URI, - HttpMethod.POST.name() - ); - this.requestMatcher = new OrRequestMatcher(userInfoGetMatcher, userInfoPostMatcher); - this.userInfoClaimsMapper = userInfoClaimsMapper; + private final HttpMessageConverter errorHttpResponseConverter = + new OAuth2ErrorHttpMessageConverter(); + + /** + * Constructs an {@code OidcUserInfoEndpointFilter} using the provided parameters. + * + * @param authenticationManager the authentication manager + */ + public OidcUserInfoEndpointFilter(AuthenticationManager authenticationManager) { + this(authenticationManager, DEFAULT_OIDC_USER_INFO_ENDPOINT_URI); + } + + /** + * Constructs an {@code OidcUserInfoEndpointFilter} using the provided parameters. + * + * @param authenticationManager the authentication manager + * @param userInfoEndpointUri the endpoint {@code URI} for OpenID Connect 1.0 UserInfo Requests + */ + public OidcUserInfoEndpointFilter(AuthenticationManager authenticationManager, String userInfoEndpointUri) { + Assert.notNull(authenticationManager, "authenticationManager cannot be null"); + Assert.hasText(userInfoEndpointUri, "userInfoEndpointUri cannot be empty"); + this.authenticationManager = authenticationManager; + this.userInfoEndpointMatcher = new OrRequestMatcher( + new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.GET.name()), + new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.POST.name())); } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - if (!this.requestMatcher.matches(request)) { + if (!this.userInfoEndpointMatcher.matches(request)) { filterChain.doFilter(request, response); return; } - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - Object authenticationDetails = authentication.getDetails(); - Object principal = authentication.getPrincipal(); - OidcUserInfo oidcUserInfo = userInfoClaimsMapper.map(principal); - - if (authenticationDetails instanceof OAuth2AccessToken) { - oidcUserInfo = getUserInfoClaimsRequestedByScope(oidcUserInfo, ((OAuth2AccessToken) authenticationDetails).getScopes()); - } else { - oidcUserInfo = OidcUserInfo.builder() - .subject(oidcUserInfo.getSubject()) - .build(); - } + try { + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - this.oidcUserInfoHttpMessageConverter.write( - oidcUserInfo, MediaType.APPLICATION_JSON, httpResponse); - } + OidcUserInfoAuthenticationToken userInfoAuthentication = new OidcUserInfoAuthenticationToken(principal); - private OidcUserInfo getUserInfoClaimsRequestedByScope(OidcUserInfo userInfo, Set scopes) { - Set scopeRequestedClaimNames = getScopeRequestedClaimNames(scopes); + OidcUserInfoAuthenticationToken userInfoAuthenticationResult = + (OidcUserInfoAuthenticationToken) this.authenticationManager.authenticate(userInfoAuthentication); - Map scopeRequestedClaims = userInfo.getClaims().entrySet().stream() - .filter(claim -> scopeRequestedClaimNames.contains(claim.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + sendUserInfoResponse(response, userInfoAuthenticationResult.getUserInfo()); - return new OidcUserInfo(scopeRequestedClaims); - } - - private Set getScopeRequestedClaimNames(Set scopes) { - Set scopeRequestedClaimNames = new HashSet<>(Arrays.asList(StandardClaimNames.SUB)); - Set profileClaimNames = new HashSet<>(Arrays.asList(StandardClaimNames.NAME, - StandardClaimNames.FAMILY_NAME, StandardClaimNames.GIVEN_NAME, StandardClaimNames.MIDDLE_NAME, - StandardClaimNames.NICKNAME, StandardClaimNames.PREFERRED_USERNAME, StandardClaimNames.PROFILE, - StandardClaimNames.PICTURE, StandardClaimNames.WEBSITE, StandardClaimNames.GENDER, - StandardClaimNames.BIRTHDATE, StandardClaimNames.ZONEINFO, StandardClaimNames.LOCALE, StandardClaimNames.UPDATED_AT)); - Set emailClaimNames = new HashSet<>(Arrays.asList(StandardClaimNames.EMAIL, StandardClaimNames.EMAIL_VERIFIED)); - String addressClaimName = StandardClaimNames.ADDRESS; - Set phoneClaimNames = new HashSet<>(Arrays.asList(StandardClaimNames.PHONE_NUMBER, StandardClaimNames.PHONE_NUMBER_VERIFIED)); - - if (scopes.contains(OidcScopes.ADDRESS)) { - scopeRequestedClaimNames.add(addressClaimName); - } - if (scopes.contains(OidcScopes.EMAIL)) { - scopeRequestedClaimNames.addAll(emailClaimNames); - } - if (scopes.contains(OidcScopes.PHONE)) { - scopeRequestedClaimNames.addAll(phoneClaimNames); - } - if (scopes.contains(OidcScopes.PROFILE)) { - scopeRequestedClaimNames.addAll(profileClaimNames); + } catch (OAuth2AuthenticationException ex) { + sendErrorResponse(response, ex.getError()); + } catch (Exception ex) { + OAuth2Error error = new OAuth2Error( + OAuth2ErrorCodes.INVALID_REQUEST, + "OpenID Connect 1.0 UserInfo Error: " + ex.getMessage(), + "https://openid.net/specs/openid-connect-core-1_0.html#UserInfoError"); + sendErrorResponse(response, error); + } finally { + SecurityContextHolder.clearContext(); } + } - return scopeRequestedClaimNames; + private void sendUserInfoResponse(HttpServletResponse response, OidcUserInfo userInfo) throws IOException { + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + this.userInfoHttpMessageConverter.write(userInfo, MediaType.APPLICATION_JSON, httpResponse); } + private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException { + HttpStatus httpStatus = HttpStatus.BAD_REQUEST; + if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) { + httpStatus = HttpStatus.UNAUTHORIZED; + } else if (error.getErrorCode().equals(OAuth2ErrorCodes.INSUFFICIENT_SCOPE)) { + httpStatus = HttpStatus.FORBIDDEN; + } + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponse.setStatusCode(httpStatus); + this.errorHttpResponseConverter.write(error, null, httpResponse); + } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcUserInfoTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcUserInfoTests.java new file mode 100644 index 00000000..7075a28c --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcUserInfoTests.java @@ -0,0 +1,308 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; + +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import java.util.function.Function; + +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.source.ImmutableJWKSet; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.SecurityContext; +import org.junit.Rule; +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.http.HttpHeaders; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; +import org.springframework.security.config.annotation.web.configurers.oauth2.server.resource.OAuth2ResourceServerConfigurer; +import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.authentication.OAuth2AuthenticationContext; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeader; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; +import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; +import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.ResultMatcher; + +import static org.springframework.test.web.servlet.ResultMatcher.matchAll; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Integration tests for the OpenID Connect 1.0 UserInfo endpoint. + * + * @author Steve Riesenberg + */ +public class OidcUserInfoTests { + private static final String DEFAULT_OIDC_USER_INFO_ENDPOINT_URI = "/userinfo"; + + @Rule + public final SpringTestRule spring = new SpringTestRule(); + + @Autowired + private MockMvc mvc; + + @Autowired + private JwtEncoder jwtEncoder; + + @Autowired + private OAuth2AuthorizationService authorizationService; + + @Test + public void requestWhenUserInfoRequestGetThenUserInfoResponse() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + OAuth2Authorization authorization = createAuthorization(); + this.authorizationService.save(authorization); + + OAuth2AccessToken accessToken = authorization.getAccessToken().getToken(); + // @formatter:off + this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI) + .header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue())) + .andExpect(status().is2xxSuccessful()) + .andExpect(userInfoResponse()); + // @formatter:on + } + + @Test + public void requestWhenUserInfoRequestPostThenUserInfoResponse() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + OAuth2Authorization authorization = createAuthorization(); + this.authorizationService.save(authorization); + + OAuth2AccessToken accessToken = authorization.getAccessToken().getToken(); + // @formatter:off + this.mvc.perform(post(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI) + .header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue())) + .andExpect(status().is2xxSuccessful()) + .andExpect(userInfoResponse()); + // @formatter:on + } + + @Test + public void requestWhenSignedJwtAndCustomUserInfoMapperThenUserInfoResponse() throws Exception { + this.spring.register(CustomUserInfoConfiguration.class).autowire(); + + OAuth2Authorization authorization = createAuthorization(); + this.authorizationService.save(authorization); + + OAuth2AccessToken accessToken = authorization.getAccessToken().getToken(); + // @formatter:off + this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI) + .header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue())) + .andExpect(status().is2xxSuccessful()) + .andExpect(userInfoResponse()); + // @formatter:on + } + + private static ResultMatcher userInfoResponse() { + // @formatter:off + return matchAll( + jsonPath("sub").value("user1"), + jsonPath("name").value("First Last"), + jsonPath("given_name").value("First"), + jsonPath("family_name").value("Last"), + jsonPath("middle_name").value("Middle"), + jsonPath("nickname").value("User"), + jsonPath("preferred_username").value("user"), + jsonPath("profile").value("https://example.com/user1"), + jsonPath("picture").value("https://example.com/user1.jpg"), + jsonPath("website").value("https://example.com"), + jsonPath("email").value("user1@example.com"), + jsonPath("email_verified").value("true"), + jsonPath("gender").value("female"), + jsonPath("birthdate").value("1970-01-01"), + jsonPath("zoneinfo").value("Europe/Paris"), + jsonPath("locale").value("en-US"), + jsonPath("phone_number").value("+1 (604) 555-1234;ext=5678"), + jsonPath("phone_number_verified").value("false"), + jsonPath("address").value("Champ de Mars\n5 Av. Anatole France\n75007 Paris\nFrance"), + jsonPath("updated_at").value("1970-01-01T00:00:00Z") + ); + // @formatter:on + } + + private OAuth2Authorization createAuthorization() { + JoseHeader headers = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + // @formatter:off + JwtClaimsSet claimSet = JwtClaimsSet.builder() + .claims(claims -> claims.putAll(createUserInfo().getClaims())) + .build(); + // @formatter:on + Jwt jwt = this.jwtEncoder.encode(headers, claimSet); + + Instant now = Instant.now(); + Set scopes = new HashSet<>(Arrays.asList( + OidcScopes.OPENID, OidcScopes.ADDRESS, OidcScopes.EMAIL, OidcScopes.PHONE, OidcScopes.PROFILE)); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, jwt.getTokenValue(), now, now.plusSeconds(300), scopes); + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .claims(claims -> claims.putAll(createUserInfo().getClaims())) + .build(); + + return TestOAuth2Authorizations.authorization() + .accessToken(accessToken) + .token(idToken) + .build(); + } + + private static OidcUserInfo createUserInfo() { + // @formatter:off + return OidcUserInfo.builder() + .subject("user1") + .name("First Last") + .givenName("First") + .familyName("Last") + .middleName("Middle") + .nickname("User") + .preferredUsername("user") + .profile("https://example.com/user1") + .picture("https://example.com/user1.jpg") + .website("https://example.com") + .email("user1@example.com") + .emailVerified(true) + .gender("female") + .birthdate("1970-01-01") + .zoneinfo("Europe/Paris") + .locale("en-US") + .phoneNumber("+1 (604) 555-1234;ext=5678") + .phoneNumberVerified("false") + .address("Champ de Mars\n5 Av. Anatole France\n75007 Paris\nFrance") + .updatedAt("1970-01-01T00:00:00Z") + .build(); + // @formatter:on + } + + @EnableWebSecurity + static class CustomUserInfoConfiguration extends AuthorizationServerConfiguration { + + @Bean + @Override + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + OAuth2AuthorizationServerConfigurer authorizationServerConfigurer = + new OAuth2AuthorizationServerConfigurer<>(); + RequestMatcher endpointsMatcher = authorizationServerConfigurer + .getEndpointsMatcher(); + + // Custom User Info Mapper that retrieves claims from a signed JWT + Function userInfoMapper = context -> { + OidcUserInfoAuthenticationToken authentication = context.getAuthentication(); + JwtAuthenticationToken principal = (JwtAuthenticationToken) authentication.getPrincipal(); + + return new OidcUserInfo(principal.getToken().getClaims()); + }; + + // @formatter:off + http + .requestMatcher(endpointsMatcher) + .authorizeRequests(authorizeRequests -> + authorizeRequests.anyRequest().authenticated() + ) + .csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher)) + .oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt) + .apply(authorizationServerConfigurer) + .oidc(oidc -> oidc + .userInfoEndpoint(userInfo -> userInfo + .userInfoMapper(userInfoMapper) + ) + ); + // @formatter:on + + return http.build(); + } + } + + @EnableWebSecurity + static class AuthorizationServerConfiguration { + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + OAuth2AuthorizationServerConfigurer authorizationServerConfigurer = + new OAuth2AuthorizationServerConfigurer<>(); + RequestMatcher endpointsMatcher = authorizationServerConfigurer + .getEndpointsMatcher(); + + // @formatter:off + http + .requestMatcher(endpointsMatcher) + .authorizeRequests(authorizeRequests -> + authorizeRequests.anyRequest().authenticated() + ) + .csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher)) + .oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt) + .apply(authorizationServerConfigurer); + // @formatter:on + + return http.build(); + } + + @Bean + RegisteredClientRepository registeredClientRepository() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + return new InMemoryRegisteredClientRepository(registeredClient); + } + + @Bean + OAuth2AuthorizationService authorizationService() { + return new InMemoryOAuth2AuthorizationService(); + } + + @Bean + JWKSource jwkSource() { + return new ImmutableJWKSet<>(new JWKSet(TestJwks.DEFAULT_RSA_JWK)); + } + + @Bean + JwtDecoder jwtDecoder(JWKSource jwkSource) { + return OAuth2AuthorizationServerConfiguration.jwtDecoder(jwkSource); + } + + @Bean + JwtEncoder jwtEncoder(JWKSource jwkSource) { + return new NimbusJwsEncoder(jwkSource); + } + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcUserInfoHttpMessageConverterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcUserInfoHttpMessageConverterTests.java new file mode 100644 index 00000000..9a3f3727 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcUserInfoHttpMessageConverterTests.java @@ -0,0 +1,224 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.core.oidc.http.converter; + +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.mock.http.MockHttpOutputMessage; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link OidcUserInfoHttpMessageConverter}. + * + * @author Steve Riesenberg + */ +public class OidcUserInfoHttpMessageConverterTests { + private final OidcUserInfoHttpMessageConverter messageConverter = new OidcUserInfoHttpMessageConverter(); + + @Test + public void supportsWhenOidcUserInfoThenTrue() { + assertThat(this.messageConverter.supports(OidcUserInfo.class)).isTrue(); + } + + @Test + public void setUserInfoConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.messageConverter.setUserInfoConverter(null)); + } + + @Test + public void setUserInfoParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.messageConverter.setUserInfoParametersConverter(null)); + } + + @Test + public void readInternalWhenValidParametersThenSuccess() { + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"user1\",\n" + + " \"name\": \"First Last\",\n" + + " \"given_name\": \"First\",\n" + + " \"family_name\": \"Last\",\n" + + " \"middle_name\": \"Middle\",\n" + + " \"nickname\": \"User\",\n" + + " \"preferred_username\": \"user\",\n" + + " \"profile\": \"https://example.com/user1\",\n" + + " \"picture\": \"https://example.com/user1.jpg\",\n" + + " \"website\": \"https://example.com\",\n" + + " \"email\": \"user1@example.com\",\n" + + " \"email_verified\": \"true\",\n" + + " \"gender\": \"female\",\n" + + " \"birthdate\": \"1970-01-01\",\n" + + " \"zoneinfo\": \"Europe/Paris\",\n" + + " \"locale\": \"en-US\",\n" + + " \"phone_number\": \"+1 (604) 555-1234;ext=5678\",\n" + + " \"phone_number_verified\": \"false\",\n" + + " \"address\": {\n" + + " \"formatted\": \"Champ de Mars\\n5 Av. Anatole France\\n75007 Paris\\nFrance\",\n" + + " \"street_address\": \"Champ de Mars\\n5 Av. Anatole France\",\n" + + " \"locality\": \"Paris\",\n" + + " \"postal_code\": \"75007\",\n" + + " \"country\": \"France\"\n" + + " },\n" + + " \"updated_at\": 1607633867\n" + + "}\n"; + // @formatter:on + + MockClientHttpResponse response = new MockClientHttpResponse(userInfoResponse.getBytes(), HttpStatus.OK); + OidcUserInfo oidcUserInfo = this.messageConverter.readInternal(OidcUserInfo.class, response); + + assertThat(oidcUserInfo.getSubject()).isEqualTo("user1"); + assertThat(oidcUserInfo.getFullName()).isEqualTo("First Last"); + assertThat(oidcUserInfo.getGivenName()).isEqualTo("First"); + assertThat(oidcUserInfo.getFamilyName()).isEqualTo("Last"); + assertThat(oidcUserInfo.getMiddleName()).isEqualTo("Middle"); + assertThat(oidcUserInfo.getNickName()).isEqualTo("User"); + assertThat(oidcUserInfo.getPreferredUsername()).isEqualTo("user"); + assertThat(oidcUserInfo.getProfile()).isEqualTo("https://example.com/user1"); + assertThat(oidcUserInfo.getPicture()).isEqualTo("https://example.com/user1.jpg"); + assertThat(oidcUserInfo.getWebsite()).isEqualTo("https://example.com"); + assertThat(oidcUserInfo.getEmail()).isEqualTo("user1@example.com"); + assertThat(oidcUserInfo.getEmailVerified()).isTrue(); + assertThat(oidcUserInfo.getGender()).isEqualTo("female"); + assertThat(oidcUserInfo.getBirthdate()).isEqualTo("1970-01-01"); + assertThat(oidcUserInfo.getZoneInfo()).isEqualTo("Europe/Paris"); + assertThat(oidcUserInfo.getLocale()).isEqualTo("en-US"); + assertThat(oidcUserInfo.getPhoneNumber()).isEqualTo("+1 (604) 555-1234;ext=5678"); + assertThat(oidcUserInfo.getPhoneNumberVerified()).isFalse(); + assertThat(oidcUserInfo.getAddress().getFormatted()).isEqualTo("Champ de Mars\n5 Av. Anatole France\n75007 Paris\nFrance"); + assertThat(oidcUserInfo.getAddress().getStreetAddress()).isEqualTo("Champ de Mars\n5 Av. Anatole France"); + assertThat(oidcUserInfo.getAddress().getLocality()).isEqualTo("Paris"); + assertThat(oidcUserInfo.getAddress().getPostalCode()).isEqualTo("75007"); + assertThat(oidcUserInfo.getAddress().getCountry()).isEqualTo("France"); + assertThat(oidcUserInfo.getUpdatedAt()).isEqualTo(Instant.ofEpochSecond(1607633867)); + } + + @Test + public void readInternalWhenFailingConverterThenThrowException() { + String errorMessage = "this is not a valid converter"; + this.messageConverter.setUserInfoConverter(source -> { + throw new RuntimeException(errorMessage); + }); + MockClientHttpResponse response = new MockClientHttpResponse("{}".getBytes(), HttpStatus.OK); + + assertThatExceptionOfType(HttpMessageNotReadableException.class) + .isThrownBy(() -> this.messageConverter.readInternal(OidcUserInfo.class, response)) + .withMessageContaining("An error occurred reading the UserInfo") + .withMessageContaining(errorMessage); + } + + @Test + public void readInternalWhenInvalidResponseThenThrowException() { + String providerConfigurationResponse = "{}"; + MockClientHttpResponse response = new MockClientHttpResponse(providerConfigurationResponse.getBytes(), HttpStatus.OK); + + assertThatExceptionOfType(HttpMessageNotReadableException.class) + .isThrownBy(() -> this.messageConverter.readInternal(OidcUserInfo.class, response)) + .withMessageContaining("An error occurred reading the UserInfo") + .withMessageContaining("claims cannot be empty"); + } + + @Test + public void writeInternalWhenOidcUserInfoThenSuccess() { + OidcUserInfo userInfo = createUserInfo(); + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + + this.messageConverter.writeInternal(userInfo, outputMessage); + + String userInfoResponse = outputMessage.getBodyAsString(); + assertThat(userInfoResponse).contains("\"sub\":\"user1\""); + assertThat(userInfoResponse).contains("\"name\":\"First Last\""); + assertThat(userInfoResponse).contains("\"given_name\":\"First\""); + assertThat(userInfoResponse).contains("\"family_name\":\"Last\""); + assertThat(userInfoResponse).contains("\"middle_name\":\"Middle\""); + assertThat(userInfoResponse).contains("\"nickname\":\"User\""); + assertThat(userInfoResponse).contains("\"preferred_username\":\"user\""); + assertThat(userInfoResponse).contains("\"profile\":\"https://example.com/user1\""); + assertThat(userInfoResponse).contains("\"picture\":\"https://example.com/user1.jpg\""); + assertThat(userInfoResponse).contains("\"website\":\"https://example.com\""); + assertThat(userInfoResponse).contains("\"email\":\"user1@example.com\""); + assertThat(userInfoResponse).contains("\"email_verified\":true"); + assertThat(userInfoResponse).contains("\"gender\":\"female\""); + assertThat(userInfoResponse).contains("\"birthdate\":\"1970-01-01\""); + assertThat(userInfoResponse).contains("\"zoneinfo\":\"Europe/Paris\""); + assertThat(userInfoResponse).contains("\"locale\":\"en-US\""); + assertThat(userInfoResponse).contains("\"phone_number\":\"+1 (604) 555-1234;ext=5678\""); + assertThat(userInfoResponse).contains("\"phone_number_verified\":false"); + assertThat(userInfoResponse).contains("\"address\":"); + assertThat(userInfoResponse).contains("\"formatted\":\"Champ de Mars\\n5 Av. Anatole France\\n75007 Paris\\nFrance\""); + assertThat(userInfoResponse).contains("\"updated_at\":1607633867"); + assertThat(userInfoResponse).contains("\"custom_claim\":\"value\""); + assertThat(userInfoResponse).contains("\"custom_collection_claim\":[\"value1\",\"value2\"]"); + } + + @Test + public void writeInternalWhenWriteFailsThenThrowsException() { + String errorMessage = "this is not a valid converter"; + Converter> failingConverter = source -> { + throw new RuntimeException(errorMessage); + }; + this.messageConverter.setUserInfoParametersConverter(failingConverter); + + OidcUserInfo userInfo = createUserInfo(); + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + + assertThatExceptionOfType(HttpMessageNotWritableException.class) + .isThrownBy(() -> this.messageConverter.writeInternal(userInfo, outputMessage)) + .withMessageContaining("An error occurred writing the UserInfo response") + .withMessageContaining(errorMessage); + } + + private static OidcUserInfo createUserInfo() { + return OidcUserInfo.builder() + .subject("user1") + .name("First Last") + .givenName("First") + .familyName("Last") + .middleName("Middle") + .nickname("User") + .preferredUsername("user") + .profile("https://example.com/user1") + .picture("https://example.com/user1.jpg") + .website("https://example.com") + .email("user1@example.com") + .emailVerified(true) + .gender("female") + .birthdate("1970-01-01") + .zoneinfo("Europe/Paris") + .locale("en-US") + .phoneNumber("+1 (604) 555-1234;ext=5678") + .claim("phone_number_verified", false) + .claim("address", Collections.singletonMap("formatted", "Champ de Mars\n5 Av. Anatole France\n75007 Paris\nFrance")) + .claim(StandardClaimNames.UPDATED_AT, Instant.ofEpochSecond(1607633867)) + .claim("custom_claim", "value") + .claim("custom_collection_claim", Arrays.asList("value1", "value2")) + .build(); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettingsTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettingsTests.java index e51e5a08..5a5c18d8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettingsTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettingsTests.java @@ -38,6 +38,7 @@ public class ProviderSettingsTests { assertThat(providerSettings.getTokenRevocationEndpoint()).isEqualTo("/oauth2/revoke"); assertThat(providerSettings.getTokenIntrospectionEndpoint()).isEqualTo("/oauth2/introspect"); assertThat(providerSettings.getOidcClientRegistrationEndpoint()).isEqualTo("/connect/register"); + assertThat(providerSettings.getOidcUserInfoEndpoint()).isEqualTo("/userinfo"); } @Test @@ -48,6 +49,7 @@ public class ProviderSettingsTests { String tokenRevocationEndpoint = "/oauth2/v1/revoke"; String tokenIntrospectionEndpoint = "/oauth2/v1/introspect"; String oidcClientRegistrationEndpoint = "/connect/v1/register"; + String oidcUserInfoEndpoint = "/connect/v1/userinfo"; String issuer = "https://example.com:9000"; ProviderSettings providerSettings = ProviderSettings.builder() @@ -59,6 +61,7 @@ public class ProviderSettingsTests { .tokenIntrospectionEndpoint(tokenIntrospectionEndpoint) .tokenRevocationEndpoint(tokenRevocationEndpoint) .oidcClientRegistrationEndpoint(oidcClientRegistrationEndpoint) + .oidcUserInfoEndpoint(oidcUserInfoEndpoint) .build(); assertThat(providerSettings.getIssuer()).isEqualTo(issuer); @@ -68,6 +71,7 @@ public class ProviderSettingsTests { assertThat(providerSettings.getTokenRevocationEndpoint()).isEqualTo(tokenRevocationEndpoint); assertThat(providerSettings.getTokenIntrospectionEndpoint()).isEqualTo(tokenIntrospectionEndpoint); assertThat(providerSettings.getOidcClientRegistrationEndpoint()).isEqualTo(oidcClientRegistrationEndpoint); + assertThat(providerSettings.getOidcUserInfoEndpoint()).isEqualTo(oidcUserInfoEndpoint); } @Test @@ -77,7 +81,7 @@ public class ProviderSettingsTests { .settings(settings -> settings.put("name2", "value2")) .build(); - assertThat(providerSettings.getSettings()).hasSize(8); + assertThat(providerSettings.getSettings()).hasSize(9); assertThat(providerSettings.getSetting("name1")).isEqualTo("value1"); assertThat(providerSettings.getSetting("name2")).isEqualTo("value2"); } @@ -124,6 +128,13 @@ public class ProviderSettingsTests { .withMessage("value cannot be null"); } + @Test + public void oidcUserInfoEndpointWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> ProviderSettings.builder().oidcUserInfoEndpoint(null)) + .withMessage("value cannot be null"); + } + @Test public void jwksEndpointWhenNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException() diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProviderTests.java new file mode 100644 index 00000000..aaf9c43e --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProviderTests.java @@ -0,0 +1,285 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeaderNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OidcUserInfoAuthenticationProvider}. + * + * @author Steve Riesenberg + */ +public class OidcUserInfoAuthenticationProviderTests { + private OAuth2AuthorizationService authorizationService; + private OidcUserInfoAuthenticationProvider authenticationProvider; + + @Before + public void setUp() throws Exception { + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.authenticationProvider = new OidcUserInfoAuthenticationProvider(authorizationService); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcUserInfoAuthenticationProvider(null)) + .withMessage("authorizationService cannot be null"); + } + + @Test + public void setUserInfoMapperWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authenticationProvider.setUserInfoMapper(null)) + .withMessage("userInfoMapper cannot be null"); + } + + @Test + public void supportsWhenTypeOidcUserInfoAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OidcUserInfoAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenPrincipalNotOfExpectedTypeThenThrowOAuth2AuthenticationException() { + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken( + new UsernamePasswordAuthenticationToken(null, null)); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + + verifyNoInteractions(this.authorizationService); + } + + @Test + public void authenticateWhenPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { + String tokenValue = "token"; + JwtAuthenticationToken principal = createJwtAuthenticationToken(tokenValue); + principal.setAuthenticated(false); + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + + verifyNoInteractions(this.authorizationService); + } + + @Test + public void authenticateWhenAccessTokenNotFoundThenThrowOAuth2AuthenticationException() { + String tokenValue = "token"; + JwtAuthenticationToken principal = createJwtAuthenticationToken(tokenValue); + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + + verify(this.authorizationService).findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenAccessTokenNotActiveThenThrowOAuth2AuthenticationException() { + String tokenValue = "token"; + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + authorization = OidcAuthenticationProviderUtils.invalidate(authorization, + authorization.getAccessToken().getToken()); + when(this.authorizationService.findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + + JwtAuthenticationToken principal = createJwtAuthenticationToken(tokenValue); + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + + verify(this.authorizationService).findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenAccessTokenNotAuthorizedThenThrowOAuth2AuthenticationException() { + String tokenValue = "token"; + when(this.authorizationService.findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN))) + .thenReturn(TestOAuth2Authorizations.authorization().build()); + + JwtAuthenticationToken principal = createJwtAuthenticationToken(tokenValue); + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + + verify(this.authorizationService).findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenIdTokenNullThenThrowOAuth2AuthenticationException() { + String tokenValue = "token"; + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization() + .token(createAuthorization(tokenValue).getAccessToken().getToken()) + .build(); + when(this.authorizationService.findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + + JwtAuthenticationToken principal = createJwtAuthenticationToken(tokenValue); + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + + verify(this.authorizationService).findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenValidAccessTokenThenReturnUserInfo() { + String tokenValue = "access-token"; + when(this.authorizationService.findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN))) + .thenReturn(createAuthorization(tokenValue)); + + JwtAuthenticationToken principal = createJwtAuthenticationToken(tokenValue); + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal); + OidcUserInfoAuthenticationToken authenticationResult = + (OidcUserInfoAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + assertThat(authenticationResult.getPrincipal()).isEqualTo(principal); + assertThat(authenticationResult.getCredentials()).isEqualTo(""); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + + OidcUserInfo userInfo = authenticationResult.getUserInfo(); + assertThat(userInfo.getClaims()).hasSize(20); + assertThat(userInfo.getSubject()).isEqualTo("user1"); + assertThat(userInfo.getFullName()).isEqualTo("First Last"); + assertThat(userInfo.getGivenName()).isEqualTo("First"); + assertThat(userInfo.getFamilyName()).isEqualTo("Last"); + assertThat(userInfo.getMiddleName()).isEqualTo("Middle"); + assertThat(userInfo.getNickName()).isEqualTo("User"); + assertThat(userInfo.getPreferredUsername()).isEqualTo("user"); + assertThat(userInfo.getProfile()).isEqualTo("https://example.com/user1"); + assertThat(userInfo.getPicture()).isEqualTo("https://example.com/user1.jpg"); + assertThat(userInfo.getWebsite()).isEqualTo("https://example.com"); + assertThat(userInfo.getEmail()).isEqualTo("user1@example.com"); + assertThat(userInfo.getEmailVerified()).isEqualTo(true); + assertThat(userInfo.getGender()).isEqualTo("female"); + assertThat(userInfo.getBirthdate()).isEqualTo("1970-01-01"); + assertThat(userInfo.getZoneInfo()).isEqualTo("Europe/Paris"); + assertThat(userInfo.getLocale()).isEqualTo("en-US"); + assertThat(userInfo.getPhoneNumber()).isEqualTo("+1 (604) 555-1234;ext=5678"); + assertThat(userInfo.getPhoneNumberVerified()).isEqualTo(false); + assertThat(userInfo.getClaimAsString(StandardClaimNames.ADDRESS)) + .isEqualTo("Champ de Mars\n5 Av. Anatole France\n75007 Paris\nFrance"); + assertThat(userInfo.getUpdatedAt()).isEqualTo(Instant.parse("1970-01-01T00:00:00Z")); + + verify(this.authorizationService).findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + private static OAuth2Authorization createAuthorization(String tokenValue) { + Instant now = Instant.now(); + Set scopes = new HashSet<>(Arrays.asList( + OidcScopes.OPENID, OidcScopes.ADDRESS, OidcScopes.EMAIL, OidcScopes.PHONE, OidcScopes.PROFILE)); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, tokenValue, now, now.plusSeconds(300), scopes); + OidcIdToken idToken = new OidcIdToken("id-token", now, now.plusSeconds(900), createUserInfo().getClaims()); + + return TestOAuth2Authorizations.authorization() + .token(accessToken) + .token(idToken) + .build(); + } + + private static JwtAuthenticationToken createJwtAuthenticationToken(String tokenValue) { + Instant now = Instant.now(); + // @formatter:off + Jwt jwt = Jwt.withTokenValue(tokenValue) + .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) + .issuedAt(now) + .expiresAt(now.plusSeconds(300)) + .claim(StandardClaimNames.SUB, "user") + .build(); + // @formatter:on + return new JwtAuthenticationToken(jwt, Collections.emptyList()); + } + + private static OidcUserInfo createUserInfo() { + return OidcUserInfo.builder() + .subject("user1") + .name("First Last") + .givenName("First") + .familyName("Last") + .middleName("Middle") + .nickname("User") + .preferredUsername("user") + .profile("https://example.com/user1") + .picture("https://example.com/user1.jpg") + .website("https://example.com") + .email("user1@example.com") + .emailVerified(true) + .gender("female") + .birthdate("1970-01-01") + .zoneinfo("Europe/Paris") + .locale("en-US") + .phoneNumber("+1 (604) 555-1234;ext=5678") + .phoneNumberVerified("false") + .address("Champ de Mars\n5 Av. Anatole France\n75007 Paris\nFrance") + .updatedAt("1970-01-01T00:00:00Z") + .build(); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationTokenTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationTokenTests.java new file mode 100644 index 00000000..ab3b9128 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationTokenTests.java @@ -0,0 +1,60 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.util.Collections; + +import org.junit.Test; + +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link OidcUserInfoAuthenticationToken}. + * + * @author Steve Riesenberg + */ +public class OidcUserInfoAuthenticationTokenTests { + @Test + public void constructorWhenPrincipalNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcUserInfoAuthenticationToken(null)) + .withMessage("principal cannot be null"); + } + + @Test + public void constructorWhenPrincipalProvidedThenCreated() { + UsernamePasswordAuthenticationToken principal = new UsernamePasswordAuthenticationToken(null, null); + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal); + assertThat(authentication.getPrincipal()).isEqualTo(principal); + assertThat(authentication.getUserInfo()).isNull(); + assertThat(authentication.isAuthenticated()).isFalse(); + } + + @Test + public void constructorWhenPrincipalAndUserInfoProvidedThenCreated() { + UsernamePasswordAuthenticationToken principal = new UsernamePasswordAuthenticationToken(null, null); + OidcUserInfo userInfo = new OidcUserInfo(Collections.singletonMap(StandardClaimNames.SUB, "user")); + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal, userInfo); + assertThat(authentication.getPrincipal()).isEqualTo(principal); + assertThat(authentication.getUserInfo()).isEqualTo(userInfo); + assertThat(authentication.isAuthenticated()).isFalse(); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java new file mode 100644 index 00000000..81d0ce26 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java @@ -0,0 +1,254 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.web; + +import java.time.Instant; +import java.util.Collections; + +import javax.servlet.FilterChain; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeaderNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OidcUserInfoEndpointFilter}. + * + * @author Steve Riesenberg + */ +public class OidcUserInfoEndpointFilterTests { + private static final String DEFAULT_OIDC_USER_INFO_ENDPOINT_URI = "/userinfo"; + private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); + + @Test + public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcUserInfoEndpointFilter(null)) + .withMessage("authenticationManager cannot be null"); + } + + @Test + public void constructorWhenUserInfoEndpointUriIsEmptyThenThrowIllegalArgumentException() { + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcUserInfoEndpointFilter(authenticationManager, "")) + .withMessage("userInfoEndpointUri cannot be empty"); + } + + @Test + public void doFilterWhenNotUserInfoRequestThenNotProcessed() throws Exception { + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + OidcUserInfoEndpointFilter userInfoEndpointFilter = + new OidcUserInfoEndpointFilter(authenticationManager, DEFAULT_OIDC_USER_INFO_ENDPOINT_URI); + + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + userInfoEndpointFilter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(request, response); + } + + @Test + public void doFilterWhenUserInfoRequestPutThenNotProcessed() throws Exception { + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + OidcUserInfoEndpointFilter userInfoEndpointFilter = + new OidcUserInfoEndpointFilter(authenticationManager, DEFAULT_OIDC_USER_INFO_ENDPOINT_URI); + + String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("PUT", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + userInfoEndpointFilter.doFilter(request, response, filterChain); + + verifyNoInteractions(authenticationManager); + verify(filterChain).doFilter(request, response); + } + + @Test + public void doFilterWhenUserInfoRequestGetThenSuccess() throws Exception { + JwtAuthenticationToken principal = createJwtAuthenticationToken(); + SecurityContextHolder.getContext().setAuthentication(principal); + + OidcUserInfoAuthenticationToken authenticationResult = new OidcUserInfoAuthenticationToken(principal, createUserInfo()); + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + when(authenticationManager.authenticate(any())).thenReturn(authenticationResult); + OidcUserInfoEndpointFilter userInfoEndpointFilter = new OidcUserInfoEndpointFilter(authenticationManager); + + String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + userInfoEndpointFilter.doFilter(request, response, filterChain); + + verify(authenticationManager).authenticate(any()); + verifyNoInteractions(filterChain); + + assertThat(response.getContentType()).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertUserInfoResponse(response.getContentAsString()); + } + + @Test + public void doFilterWhenUserInfoRequestPostThenSuccess() throws Exception { + JwtAuthenticationToken principal = createJwtAuthenticationToken(); + SecurityContextHolder.getContext().setAuthentication(principal); + + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal, createUserInfo()); + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + when(authenticationManager.authenticate(any())).thenReturn(authentication); + OidcUserInfoEndpointFilter userInfoEndpointFilter = new OidcUserInfoEndpointFilter(authenticationManager); + + String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + userInfoEndpointFilter.doFilter(request, response, filterChain); + + verify(authenticationManager).authenticate(any()); + verifyNoInteractions(filterChain); + + assertThat(response.getContentType()).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertUserInfoResponse(response.getContentAsString()); + } + + @Test + public void doFilterWhenAuthenticationNullThenInvalidRequestError() throws Exception { + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + when(authenticationManager.authenticate(any(Authentication.class))) + .thenReturn(new UsernamePasswordAuthenticationToken("user", "password")); + OidcUserInfoEndpointFilter userInfoEndpointFilter = new OidcUserInfoEndpointFilter(authenticationManager); + + String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addHeader(HttpHeaders.AUTHORIZATION, "Bearer token"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + userInfoEndpointFilter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + OAuth2Error error = readError(response); + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + assertThat(error.getDescription()).isEqualTo("OpenID Connect 1.0 UserInfo Error: principal cannot be null"); + } + + private OAuth2Error readError(MockHttpServletResponse response) throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); + return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse); + } + + private JwtAuthenticationToken createJwtAuthenticationToken() { + Instant now = Instant.now(); + // @formatter:off + Jwt jwt = Jwt.withTokenValue("token") + .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) + .issuedAt(now) + .expiresAt(now.plusSeconds(300)) + .claim(StandardClaimNames.SUB, "user") + .build(); + // @formatter:on + return new JwtAuthenticationToken(jwt, Collections.emptyList()); + } + + private static OidcUserInfo createUserInfo() { + return OidcUserInfo.builder() + .subject("user1") + .name("First Last") + .givenName("First") + .familyName("Last") + .middleName("Middle") + .nickname("User") + .preferredUsername("user") + .profile("https://example.com/user1") + .picture("https://example.com/user1.jpg") + .website("https://example.com") + .email("user1@example.com") + .emailVerified(true) + .gender("female") + .birthdate("1970-01-01") + .zoneinfo("Europe/Paris") + .locale("en-US") + .phoneNumber("+1 (604) 555-1234;ext=5678") + .phoneNumberVerified("false") + .address("Champ de Mars\n5 Av. Anatole France\n75007 Paris\nFrance") + .updatedAt("1970-01-01T00:00:00Z") + .build(); + } + + private static void assertUserInfoResponse(String userInfoResponse) { + assertThat(userInfoResponse).contains("\"sub\":\"user1\""); + assertThat(userInfoResponse).contains("\"name\":\"First Last\""); + assertThat(userInfoResponse).contains("\"given_name\":\"First\""); + assertThat(userInfoResponse).contains("\"family_name\":\"Last\""); + assertThat(userInfoResponse).contains("\"middle_name\":\"Middle\""); + assertThat(userInfoResponse).contains("\"nickname\":\"User\""); + assertThat(userInfoResponse).contains("\"preferred_username\":\"user\""); + assertThat(userInfoResponse).contains("\"profile\":\"https://example.com/user1\""); + assertThat(userInfoResponse).contains("\"picture\":\"https://example.com/user1.jpg\""); + assertThat(userInfoResponse).contains("\"website\":\"https://example.com\""); + assertThat(userInfoResponse).contains("\"email\":\"user1@example.com\""); + assertThat(userInfoResponse).contains("\"email_verified\":true"); + assertThat(userInfoResponse).contains("\"gender\":\"female\""); + assertThat(userInfoResponse).contains("\"birthdate\":\"1970-01-01\""); + assertThat(userInfoResponse).contains("\"zoneinfo\":\"Europe/Paris\""); + assertThat(userInfoResponse).contains("\"locale\":\"en-US\""); + assertThat(userInfoResponse).contains("\"phone_number\":\"+1 (604) 555-1234;ext=5678\""); + assertThat(userInfoResponse).contains("\"phone_number_verified\":\"false\""); + assertThat(userInfoResponse).contains("\"address\":\"Champ de Mars\\n5 Av. Anatole France\\n75007 Paris\\nFrance\""); + assertThat(userInfoResponse).contains("\"updated_at\":\"1970-01-01T00:00:00Z\""); + } +}