From 1da383b3608064bffa05d8d265a30ecc0bd7bbaa Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Mon, 5 Aug 2024 10:39:26 -0600 Subject: [PATCH] Add OpenSAML 5 Support Issue gh-11658 --- .../saml2/Saml2LoginConfigurer.java | 60 +- .../saml2/Saml2LogoutConfigurer.java | 61 +- .../saml2/Saml2MetadataConfigurer.java | 14 + .../Saml2LoginBeanDefinitionParserUtils.java | 22 +- .../Saml2LogoutBeanDefinitionParserUtils.java | 39 +- .../ROOT/pages/servlet/saml2/opensaml.adoc | 0 gradle/libs.versions.toml | 5 + ...ing-security-saml2-service-provider.gradle | 41 +- .../BaseOpenSamlAuthenticationProvider.java | 21 +- ...nSamlAssertingPartyMetadataRepository.java | 168 +--- .../registration/OpenSamlMetadataUtils.java | 67 +- ...Saml4AssertingPartyMetadataRepository.java | 165 ++- .../OpenSaml4AuthenticationProviderTests.java | 4 +- .../TestCustomOpenSaml4Objects.java} | 10 +- .../saml2/internal/OpenSaml5Template.java | 617 ++++++++++++ .../OpenSaml5AuthenticationProvider.java | 496 +++++++++ .../authentication/OpenSaml5Template.java | 617 ++++++++++++ .../OpenSaml5LogoutRequestValidator.java | 36 + .../OpenSaml5LogoutResponseValidator.java | 36 + .../logout/OpenSaml5Template.java | 617 ++++++++++++ .../authentication/logout/Saml2Utils.java | 196 ++++ .../metadata/OpenSaml5MetadataResolver.java | 117 +++ .../service/metadata/OpenSaml5Template.java | 617 ++++++++++++ ...Saml5AssertingPartyMetadataRepository.java | 318 ++++++ .../registration/OpenSaml5Template.java | 617 ++++++++++++ ...OpenSaml5AuthenticationTokenConverter.java | 104 ++ .../service/web/OpenSaml5Template.java | 617 ++++++++++++ ...penSaml5AuthenticationRequestResolver.java | 144 +++ .../web/authentication/OpenSaml5Template.java | 617 ++++++++++++ .../OpenSaml5LogoutRequestResolver.java | 142 +++ ...outRequestValidatorParametersResolver.java | 100 ++ .../OpenSaml5LogoutResponseResolver.java | 130 +++ .../logout/OpenSaml5Template.java | 617 ++++++++++++ .../OpenSaml5AuthenticationProviderTests.java | 945 ++++++++++++++++++ .../TestCustomOpenSaml5Objects.java | 213 ++++ .../OpenSaml5LogoutRequestValidatorTests.java | 223 +++++ ...OpenSaml5LogoutResponseValidatorTests.java | 190 ++++ .../OpenSaml5MetadataResolverTests.java | 185 ++++ ...AssertingPartyMetadataRepositoryTests.java | 379 +++++++ ...aml5AuthenticationTokenConverterTests.java | 246 +++++ ...ml5AuthenticationRequestResolverTests.java | 110 ++ .../OpenSaml5SigningUtilsTests.java | 93 ++ .../OpenSaml5LogoutRequestResolverTests.java | 94 ++ ...questValidatorParametersResolverTests.java | 153 +++ .../OpenSaml5LogoutResponseResolverTests.java | 80 ++ 45 files changed, 10066 insertions(+), 277 deletions(-) create mode 100644 docs/modules/ROOT/pages/servlet/saml2/opensaml.adoc rename saml2/saml2-service-provider/src/{test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObjects.java => opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSaml4Objects.java} (95%) create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/internal/OpenSaml5Template.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5AuthenticationProvider.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5Template.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutRequestValidator.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutResponseValidator.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5Template.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/Saml2Utils.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5MetadataResolver.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5Template.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepository.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5Template.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverter.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/OpenSaml5Template.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolver.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5Template.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestResolver.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolver.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutResponseResolver.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5Template.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5AuthenticationProviderTests.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSaml5Objects.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutRequestValidatorTests.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutResponseValidatorTests.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5MetadataResolverTests.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5SigningUtilsTests.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestResolverTests.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java create mode 100644 saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutResponseResolverTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index d5cd86550a..56b6052756 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import jakarta.servlet.http.HttpServletRequest; +import org.opensaml.core.Version; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; @@ -35,15 +36,19 @@ import org.springframework.security.config.annotation.web.configurers.CsrfConfig import org.springframework.security.core.Authentication; import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; +import org.springframework.security.saml2.provider.service.authentication.OpenSaml5AuthenticationProvider; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.OpenSaml4AuthenticationTokenConverter; +import org.springframework.security.saml2.provider.service.web.OpenSaml5AuthenticationTokenConverter; +import org.springframework.security.saml2.provider.service.web.OpenSamlAuthenticationTokenConverter; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationRequestFilter; import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver; +import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml5AuthenticationRequestResolver; import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver; import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter; import org.springframework.security.web.AuthenticationEntryPoint; @@ -115,6 +120,8 @@ import org.springframework.util.StringUtils; public final class Saml2LoginConfigurer> extends AbstractAuthenticationFilterConfigurer, Saml2WebSsoAuthenticationFilter> { + private static final boolean USE_OPENSAML_5 = Version.getVersion().startsWith("5"); + private String loginPage; private String authenticationRequestUri = "/saml2/authenticate"; @@ -366,10 +373,18 @@ public final class Saml2LoginConfigurer> if (bean != null) { return bean; } - OpenSaml4AuthenticationRequestResolver openSaml4AuthenticationRequestResolver = new OpenSaml4AuthenticationRequestResolver( - relyingPartyRegistrationRepository(http)); - openSaml4AuthenticationRequestResolver.setRequestMatcher(this.authenticationRequestMatcher); - return openSaml4AuthenticationRequestResolver; + if (USE_OPENSAML_5) { + OpenSaml5AuthenticationRequestResolver openSamlAuthenticationRequestResolver = new OpenSaml5AuthenticationRequestResolver( + relyingPartyRegistrationRepository(http)); + openSamlAuthenticationRequestResolver.setRequestMatcher(this.authenticationRequestMatcher); + return openSamlAuthenticationRequestResolver; + } + else { + OpenSaml4AuthenticationRequestResolver openSamlAuthenticationRequestResolver = new OpenSaml4AuthenticationRequestResolver( + relyingPartyRegistrationRepository(http)); + openSamlAuthenticationRequestResolver.setRequestMatcher(this.authenticationRequestMatcher); + return openSamlAuthenticationRequestResolver; + } } private AuthenticationConverter getAuthenticationConverter(B http) { @@ -379,22 +394,45 @@ public final class Saml2LoginConfigurer> AuthenticationConverter authenticationConverterBean = getBeanOrNull(http, Saml2AuthenticationTokenConverter.class); if (authenticationConverterBean == null) { - authenticationConverterBean = getBeanOrNull(http, OpenSaml4AuthenticationTokenConverter.class); + authenticationConverterBean = getBeanOrNull(http, OpenSamlAuthenticationTokenConverter.class); } - if (authenticationConverterBean == null) { - OpenSaml4AuthenticationTokenConverter converter = new OpenSaml4AuthenticationTokenConverter( + if (authenticationConverterBean != null) { + return authenticationConverterBean; + } + if (USE_OPENSAML_5) { + authenticationConverterBean = getBeanOrNull(http, OpenSaml5AuthenticationTokenConverter.class); + if (authenticationConverterBean != null) { + return authenticationConverterBean; + } + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter( this.relyingPartyRegistrationRepository); converter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http)); converter.setRequestMatcher(this.loginProcessingUrl); return converter; } - return authenticationConverterBean; + authenticationConverterBean = getBeanOrNull(http, OpenSaml4AuthenticationTokenConverter.class); + if (authenticationConverterBean != null) { + return authenticationConverterBean; + } + OpenSaml4AuthenticationTokenConverter converter = new OpenSaml4AuthenticationTokenConverter( + this.relyingPartyRegistrationRepository); + converter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http)); + converter.setRequestMatcher(this.loginProcessingUrl); + return converter; } private void registerDefaultAuthenticationProvider(B http) { - OpenSaml4AuthenticationProvider provider = getBeanOrNull(http, OpenSaml4AuthenticationProvider.class); - if (provider == null) { - http.authenticationProvider(postProcess(new OpenSaml4AuthenticationProvider())); + if (USE_OPENSAML_5) { + OpenSaml5AuthenticationProvider provider = getBeanOrNull(http, OpenSaml5AuthenticationProvider.class); + if (provider == null) { + http.authenticationProvider(postProcess(new OpenSaml5AuthenticationProvider())); + } + } + else { + OpenSaml4AuthenticationProvider provider = getBeanOrNull(http, OpenSaml4AuthenticationProvider.class); + if (provider == null) { + http.authenticationProvider(postProcess(new OpenSaml4AuthenticationProvider())); + } } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java index d746732e5b..d3e5dd912b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; import jakarta.servlet.http.HttpServletRequest; +import org.opensaml.core.Version; import org.springframework.context.ApplicationContext; import org.springframework.security.authentication.AuthenticationManager; @@ -33,17 +34,23 @@ import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal; import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml4LogoutRequestValidator; import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml4LogoutResponseValidator; +import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml5LogoutRequestValidator; +import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml5LogoutResponseValidator; import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidator; import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutResponseValidator; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.web.authentication.logout.HttpSessionLogoutRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutRequestResolver; -import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutResponseResolver; import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutRequestValidatorParametersResolver; +import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutResponseResolver; +import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml5LogoutRequestResolver; +import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml5LogoutRequestValidatorParametersResolver; +import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml5LogoutResponseResolver; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestFilter; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestResolver; +import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestValidatorParametersResolver; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutResponseFilter; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutResponseResolver; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2RelyingPartyInitiatedLogoutSuccessHandler; @@ -106,6 +113,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher; public final class Saml2LogoutConfigurer> extends AbstractHttpConfigurer, H> { + private static final boolean USE_OPENSAML_5 = Version.getVersion().startsWith("5"); + private ApplicationContext context; private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; @@ -250,14 +259,26 @@ public final class Saml2LogoutConfigurer> RelyingPartyRegistrationRepository registrations) { LogoutHandler[] logoutHandlers = this.logoutHandlers.toArray(new LogoutHandler[0]); Saml2LogoutResponseResolver logoutResponseResolver = createSaml2LogoutResponseResolver(registrations); + Saml2LogoutRequestFilter filter = new Saml2LogoutRequestFilter( + createSaml2LogoutResponseParametersResolver(registrations), + this.logoutRequestConfigurer.logoutRequestValidator(), logoutResponseResolver, logoutHandlers); + filter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy()); + return postProcess(filter); + } + + private Saml2LogoutRequestValidatorParametersResolver createSaml2LogoutResponseParametersResolver( + RelyingPartyRegistrationRepository registrations) { RequestMatcher requestMatcher = createLogoutRequestMatcher(); + if (USE_OPENSAML_5) { + OpenSaml5LogoutRequestValidatorParametersResolver parameters = new OpenSaml5LogoutRequestValidatorParametersResolver( + registrations); + parameters.setRequestMatcher(requestMatcher); + return parameters; + } OpenSaml4LogoutRequestValidatorParametersResolver parameters = new OpenSaml4LogoutRequestValidatorParametersResolver( registrations); parameters.setRequestMatcher(requestMatcher); - Saml2LogoutRequestFilter filter = new Saml2LogoutRequestFilter(parameters, - this.logoutRequestConfigurer.logoutRequestValidator(), logoutResponseResolver, logoutHandlers); - filter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy()); - return postProcess(filter); + return parameters; } private Saml2LogoutResponseFilter createLogoutResponseProcessingFilter( @@ -397,16 +418,22 @@ public final class Saml2LogoutConfigurer> } private Saml2LogoutRequestValidator logoutRequestValidator() { - if (this.logoutRequestValidator == null) { - return new OpenSaml4LogoutRequestValidator(); + if (this.logoutRequestValidator != null) { + return this.logoutRequestValidator; } - return this.logoutRequestValidator; + if (USE_OPENSAML_5) { + return new OpenSaml5LogoutRequestValidator(); + } + return new OpenSaml4LogoutRequestValidator(); } private Saml2LogoutRequestResolver logoutRequestResolver(RelyingPartyRegistrationRepository registrations) { if (this.logoutRequestResolver != null) { return this.logoutRequestResolver; } + if (USE_OPENSAML_5) { + return new OpenSaml5LogoutRequestResolver(registrations); + } return new OpenSaml4LogoutRequestResolver(registrations); } @@ -473,17 +500,23 @@ public final class Saml2LogoutConfigurer> } private Saml2LogoutResponseValidator logoutResponseValidator() { - if (this.logoutResponseValidator == null) { - return new OpenSaml4LogoutResponseValidator(); + if (this.logoutResponseValidator != null) { + return this.logoutResponseValidator; } - return this.logoutResponseValidator; + if (USE_OPENSAML_5) { + return new OpenSaml5LogoutResponseValidator(); + } + return new OpenSaml4LogoutResponseValidator(); } private Saml2LogoutResponseResolver logoutResponseResolver(RelyingPartyRegistrationRepository registrations) { - if (this.logoutResponseResolver == null) { - return new OpenSaml4LogoutResponseResolver(registrations); + if (this.logoutResponseResolver != null) { + return this.logoutResponseResolver; + } + if (USE_OPENSAML_5) { + return new OpenSaml5LogoutResponseResolver(registrations); } - return this.logoutResponseResolver; + return new OpenSaml4LogoutResponseResolver(registrations); } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2MetadataConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2MetadataConfigurer.java index e2b9d66360..baf82dc635 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2MetadataConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2MetadataConfigurer.java @@ -18,11 +18,14 @@ package org.springframework.security.config.annotation.web.configurers.saml2; import java.util.function.Function; +import org.opensaml.core.Version; + import org.springframework.context.ApplicationContext; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.saml2.provider.service.metadata.OpenSaml4MetadataResolver; +import org.springframework.security.saml2.provider.service.metadata.OpenSaml5MetadataResolver; import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResponseResolver; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; @@ -73,6 +76,8 @@ import org.springframework.util.Assert; public class Saml2MetadataConfigurer> extends AbstractHttpConfigurer, H> { + private static final boolean USE_OPENSAML_5 = Version.getVersion().startsWith("5"); + private final ApplicationContext context; private Function metadataResponseResolver; @@ -103,6 +108,12 @@ public class Saml2MetadataConfigurer> public Saml2MetadataConfigurer metadataUrl(String metadataUrl) { Assert.hasText(metadataUrl, "metadataUrl cannot be empty"); this.metadataResponseResolver = (registrations) -> { + if (USE_OPENSAML_5) { + RequestMatcherMetadataResponseResolver metadata = new RequestMatcherMetadataResponseResolver( + registrations, new OpenSaml5MetadataResolver()); + metadata.setRequestMatcher(new AntPathRequestMatcher(metadataUrl)); + return metadata; + } RequestMatcherMetadataResponseResolver metadata = new RequestMatcherMetadataResponseResolver(registrations, new OpenSaml4MetadataResolver()); metadata.setRequestMatcher(new AntPathRequestMatcher(metadataUrl)); @@ -143,6 +154,9 @@ public class Saml2MetadataConfigurer> return metadataResponseResolver; } RelyingPartyRegistrationRepository registrations = getRelyingPartyRegistrationRepository(http); + if (USE_OPENSAML_5) { + return new RequestMatcherMetadataResponseResolver(registrations, new OpenSaml5MetadataResolver()); + } return new RequestMatcherMetadataResponseResolver(registrations, new OpenSaml4MetadataResolver()); } diff --git a/config/src/main/java/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserUtils.java b/config/src/main/java/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserUtils.java index 530aba60f8..17d5469d05 100644 --- a/config/src/main/java/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserUtils.java +++ b/config/src/main/java/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserUtils.java @@ -16,6 +16,7 @@ package org.springframework.security.config.http; +import org.opensaml.core.Version; import org.w3c.dom.Element; import org.springframework.beans.BeanMetadataElement; @@ -23,10 +24,14 @@ import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; +import org.springframework.security.saml2.provider.service.authentication.OpenSaml5AuthenticationProvider; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter; +import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver; +import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml5AuthenticationRequestResolver; import org.springframework.util.StringUtils; /** @@ -35,6 +40,8 @@ import org.springframework.util.StringUtils; */ final class Saml2LoginBeanDefinitionParserUtils { + private static final boolean USE_OPENSAML_5 = Version.getVersion().startsWith("5"); + private static final String ATT_RELYING_PARTY_REGISTRATION_REPOSITORY_REF = "relying-party-registration-repository-ref"; private static final String ATT_AUTHENTICATION_REQUEST_REPOSITORY_REF = "authentication-request-repository-ref"; @@ -78,16 +85,21 @@ final class Saml2LoginBeanDefinitionParserUtils { .rootBeanDefinition(DefaultRelyingPartyRegistrationResolver.class) .addConstructorArgValue(relyingPartyRegistrationRepository) .getBeanDefinition(); - return BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver") + if (USE_OPENSAML_5) { + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml5AuthenticationRequestResolver.class) + .addConstructorArgValue(defaultRelyingPartyRegistrationResolver) + .getBeanDefinition(); + } + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml4AuthenticationRequestResolver.class) .addConstructorArgValue(defaultRelyingPartyRegistrationResolver) .getBeanDefinition(); } static BeanDefinition createAuthenticationProvider() { - return BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider") - .getBeanDefinition(); + if (USE_OPENSAML_5) { + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml5AuthenticationProvider.class).getBeanDefinition(); + } + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml4AuthenticationProvider.class).getBeanDefinition(); } static BeanMetadataElement getAuthenticationConverter(Element element) { diff --git a/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserUtils.java b/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserUtils.java index c7cb1792d5..647add7e2c 100644 --- a/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserUtils.java +++ b/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserUtils.java @@ -16,15 +16,22 @@ package org.springframework.security.config.http; +import org.opensaml.core.Version; import org.w3c.dom.Element; import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionBuilder; -import org.springframework.security.saml2.provider.service.authentication.logout.OpenSamlLogoutRequestValidator; -import org.springframework.security.saml2.provider.service.authentication.logout.OpenSamlLogoutResponseValidator; +import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml4LogoutRequestValidator; +import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml4LogoutResponseValidator; +import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml5LogoutRequestValidator; +import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml5LogoutResponseValidator; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.web.authentication.logout.HttpSessionLogoutRequestRepository; +import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutRequestResolver; +import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutResponseResolver; +import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml5LogoutRequestResolver; +import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml5LogoutResponseResolver; import org.springframework.util.StringUtils; /** @@ -33,6 +40,8 @@ import org.springframework.util.StringUtils; */ final class Saml2LogoutBeanDefinitionParserUtils { + private static final boolean USE_OPENSAML_5 = Version.getVersion().startsWith("5"); + private static final String ATT_RELYING_PARTY_REGISTRATION_REPOSITORY_REF = "relying-party-registration-repository-ref"; private static final String ATT_LOGOUT_REQUEST_VALIDATOR_REF = "logout-request-validator-ref"; @@ -62,8 +71,12 @@ final class Saml2LogoutBeanDefinitionParserUtils { if (StringUtils.hasText(logoutResponseResolver)) { return new RuntimeBeanReference(logoutResponseResolver); } - return BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutResponseResolver") + if (USE_OPENSAML_5) { + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml5LogoutResponseResolver.class) + .addConstructorArgValue(registrations) + .getBeanDefinition(); + } + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml4LogoutResponseResolver.class) .addConstructorArgValue(registrations) .getBeanDefinition(); } @@ -73,7 +86,10 @@ final class Saml2LogoutBeanDefinitionParserUtils { if (StringUtils.hasText(logoutRequestValidator)) { return new RuntimeBeanReference(logoutRequestValidator); } - return BeanDefinitionBuilder.rootBeanDefinition(OpenSamlLogoutRequestValidator.class).getBeanDefinition(); + if (USE_OPENSAML_5) { + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml5LogoutRequestValidator.class).getBeanDefinition(); + } + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml4LogoutRequestValidator.class).getBeanDefinition(); } static BeanMetadataElement getLogoutResponseValidator(Element element) { @@ -81,7 +97,10 @@ final class Saml2LogoutBeanDefinitionParserUtils { if (StringUtils.hasText(logoutResponseValidator)) { return new RuntimeBeanReference(logoutResponseValidator); } - return BeanDefinitionBuilder.rootBeanDefinition(OpenSamlLogoutResponseValidator.class).getBeanDefinition(); + if (USE_OPENSAML_5) { + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml5LogoutResponseValidator.class).getBeanDefinition(); + } + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml4LogoutResponseValidator.class).getBeanDefinition(); } static BeanMetadataElement getLogoutRequestRepository(Element element) { @@ -97,8 +116,12 @@ final class Saml2LogoutBeanDefinitionParserUtils { if (StringUtils.hasText(logoutRequestResolver)) { return new RuntimeBeanReference(logoutRequestResolver); } - return BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutRequestResolver") + if (USE_OPENSAML_5) { + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml5LogoutRequestResolver.class) + .addConstructorArgValue(registrations) + .getBeanDefinition(); + } + return BeanDefinitionBuilder.rootBeanDefinition(OpenSaml4LogoutRequestResolver.class) .addConstructorArgValue(registrations) .getBeanDefinition(); } diff --git a/docs/modules/ROOT/pages/servlet/saml2/opensaml.adoc b/docs/modules/ROOT/pages/servlet/saml2/opensaml.adoc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c16d069519..8b33b5b892 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -13,6 +13,7 @@ org-jetbrains-kotlin = "1.9.25" org-jetbrains-kotlinx = "1.8.1" org-mockito = "5.11.0" org-opensaml = "4.3.2" +org-opensaml5 = "5.1.2" org-springframework = "6.2.0-M6" [libraries] @@ -79,6 +80,10 @@ org-mockito-mockito-bom = { module = "org.mockito:mockito-bom", version.ref = "o org-opensaml-opensaml-core = { module = "org.opensaml:opensaml-core", version.ref = "org-opensaml" } org-opensaml-opensaml-saml-api = { module = "org.opensaml:opensaml-saml-api", version.ref = "org-opensaml" } org-opensaml-opensaml-saml-impl = { module = "org.opensaml:opensaml-saml-impl", version.ref = "org-opensaml" } +org-opensaml-opensaml5-core-api = { module = "org.opensaml:opensaml-core-api", version.ref = "org-opensaml5" } +org-opensaml-opensaml5-core-impl = { module = "org.opensaml:opensaml-core-impl", version.ref = "org-opensaml5" } +org-opensaml-opensaml5-saml-api = { module = "org.opensaml:opensaml-saml-api", version.ref = "org-opensaml5" } +org-opensaml-opensaml5-saml-impl = { module = "org.opensaml:opensaml-saml-impl", version.ref = "org-opensaml5" } org-python-jython = { module = "org.python:jython", version = "2.5.3" } org-seleniumhq-selenium-htmlunit-driver = "org.seleniumhq.selenium:htmlunit3-driver:4.20.0" org-seleniumhq-selenium-selenium-java = "org.seleniumhq.selenium:selenium-java:4.20.0" diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index a2107e1774..4b538a037f 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -2,7 +2,9 @@ apply plugin: 'io.spring.convention.spring-module' configurations { opensaml4Main { extendsFrom(optional, provided) } + opensaml5Main { extendsFrom(optional, provided) } opensaml4Test { extendsFrom(opensaml4Main, tests) } + opensaml5Test { extendsFrom(opensaml5Main, tests) } } sourceSets { @@ -12,12 +14,24 @@ sourceSets { srcDir 'src/opensaml4Main/java' } } + opensaml5Main { + java { + compileClasspath += main.output + configurations.opensaml5Main + srcDir 'src/opensaml5Main/java' + } + } opensaml4Test { java { compileClasspath += main.output + test.output + opensaml4Main.output + configurations.opensaml4Test srcDir 'src/opensaml4Test/java' } } + opensaml5Test { + java { + compileClasspath += main.output + test.output + opensaml5Main.output + configurations.opensaml5Test + srcDir 'src/opensaml5Test/java' + } + } } sourceSets.configureEach { set -> @@ -74,11 +88,20 @@ sourceSets.configureEach { set -> dependencies { management platform(project(":spring-security-dependencies")) api project(':spring-security-web') - api "org.opensaml:opensaml-core" - api ("org.opensaml:opensaml-saml-api") { + api 'org.opensaml:opensaml-core' + api ('org.opensaml:opensaml-saml-api') { exclude group: 'commons-logging', module: 'commons-logging' } - api ("org.opensaml:opensaml-saml-impl") { + api ('org.opensaml:opensaml-saml-impl') { + exclude group: 'commons-logging', module: 'commons-logging' + } + + opensaml5MainImplementation libs.org.opensaml.opensaml5.core.api + opensaml5MainImplementation libs.org.opensaml.opensaml5.core.impl + opensaml5MainImplementation (libs.org.opensaml.opensaml5.saml.api) { + exclude group: 'commons-logging', module: 'commons-logging' + } + opensaml5MainImplementation (libs.org.opensaml.opensaml5.saml.impl) { exclude group: 'commons-logging', module: 'commons-logging' } @@ -100,21 +123,24 @@ dependencies { jar { duplicatesStrategy = DuplicatesStrategy.EXCLUDE from sourceSets.opensaml4Main.output + from sourceSets.opensaml5Main.output } sourcesJar { duplicatesStrategy = DuplicatesStrategy.EXCLUDE from sourceSets.opensaml4Main.allJava + from sourceSets.opensaml5Main.allJava } testJar { duplicatesStrategy = DuplicatesStrategy.EXCLUDE from sourceSets.opensaml4Test.output + from sourceSets.opensaml5Test.output } javadoc { - classpath += sourceSets.opensaml4Main.runtimeClasspath - source += sourceSets.opensaml4Main.allJava + classpath += sourceSets.opensaml4Main.runtimeClasspath + sourceSets.opensaml5Main.runtimeClasspath + source += sourceSets.opensaml4Main.allJava + sourceSets.opensaml5Main.allJava } tasks.register("opensaml4Test", Test) { @@ -122,8 +148,13 @@ tasks.register("opensaml4Test", Test) { classpath += sourceSets.opensaml4Test.output } +tasks.register("opensaml5Test", Test) { + useJUnitPlatform() + classpath += sourceSets.opensaml5Test.output +} tasks.named("test") { dependsOn opensaml4Test + dependsOn opensaml5Test } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/BaseOpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/BaseOpenSamlAuthenticationProvider.java index fa6035f95c..a4c76c0c66 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/BaseOpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/BaseOpenSamlAuthenticationProvider.java @@ -16,7 +16,6 @@ package org.springframework.security.saml2.provider.service.authentication; -import java.lang.reflect.Field; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; @@ -85,7 +84,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider { @@ -510,22 +508,11 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider { return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message)); } String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(), - ((Response) assertion.getParent()).getID(), contextToString(context)); + ((Response) assertion.getParent()).getID(), context.getValidationFailureMessage()); return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message)); }; } - private static String contextToString(ValidationContext context) { - StringBuilder sb = new StringBuilder(); - for (Field field : context.getClass().getDeclaredFields()) { - ReflectionUtils.makeAccessible(field); - Object value = ReflectionUtils.getField(field, context); - sb.append(field.getName() + " = " + value + ","); - } - sb.deleteCharAt(sb.length() - 1); - return sb.toString(); - } - private static ValidationContext createValidationContext(AssertionToken assertionToken, Consumer> paramsConsumer) { Saml2AuthenticationToken token = assertionToken.token; @@ -566,7 +553,7 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider { return (serialized != null) ? serialized.getId() : null; } - private static class SAML20AssertionValidators { + static class SAML20AssertionValidators { private static final Collection conditions = new ArrayList<>(); @@ -610,8 +597,8 @@ class BaseOpenSamlAuthenticationProvider implements AuthenticationProvider { }); } - private static final SAML20AssertionValidator attributeValidator = new SAML20AssertionValidator(conditions, - subjects, statements, null, null, null) { + static final SAML20AssertionValidator attributeValidator = new SAML20AssertionValidator(conditions, subjects, + statements, null, null, null) { @Nonnull @Override protected ValidationResult validateSignature(Assertion token, ValidationContext context) { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/BaseOpenSamlAssertingPartyMetadataRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/BaseOpenSamlAssertingPartyMetadataRepository.java index 8de0b9b13a..e54b3789b9 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/BaseOpenSamlAssertingPartyMetadataRepository.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/BaseOpenSamlAssertingPartyMetadataRepository.java @@ -16,41 +16,21 @@ package org.springframework.security.saml2.provider.service.registration; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.net.URI; -import java.net.URL; -import java.util.ArrayList; -import java.util.Collection; import java.util.Iterator; import java.util.Set; -import java.util.function.Consumer; import java.util.function.Supplier; -import javax.annotation.Nonnull; - import org.opensaml.core.criterion.EntityIdCriterion; -import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.saml.criterion.EntityRoleCriterion; import org.opensaml.saml.metadata.IterableMetadataSource; import org.opensaml.saml.metadata.resolver.MetadataResolver; -import org.opensaml.saml.metadata.resolver.filter.impl.SignatureValidationFilter; import org.opensaml.saml.metadata.resolver.impl.AbstractBatchMetadataResolver; import org.opensaml.saml.metadata.resolver.impl.ResourceBackedMetadataResolver; import org.opensaml.saml.metadata.resolver.index.MetadataIndex; import org.opensaml.saml.metadata.resolver.index.impl.RoleMetadataIndex; import org.opensaml.saml.saml2.metadata.EntityDescriptor; import org.opensaml.saml.saml2.metadata.IDPSSODescriptor; -import org.opensaml.security.credential.Credential; -import org.opensaml.security.credential.impl.CollectionCredentialResolver; -import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; -import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; -import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; -import org.springframework.core.io.DefaultResourceLoader; -import org.springframework.core.io.Resource; -import org.springframework.core.io.ResourceLoader; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.security.saml2.Saml2Exception; @@ -150,148 +130,16 @@ class BaseOpenSamlAssertingPartyMetadataRepository implements AssertingPartyMeta } } - /** - * A builder class for configuring - * {@link BaseOpenSamlAssertingPartyMetadataRepository} for a specific metadata - * location. - * - * @author Josh Cummings - */ - static final class MetadataLocationRepositoryBuilder { - - private final String metadataLocation; - - private final boolean requireVerificationCredentials; - - private final Collection verificationCredentials = new ArrayList<>(); - - private ResourceLoader resourceLoader = new DefaultResourceLoader(); - - MetadataLocationRepositoryBuilder(String metadataLocation, boolean trusted) { - this.metadataLocation = metadataLocation; - this.requireVerificationCredentials = !trusted; - } - - MetadataLocationRepositoryBuilder verificationCredentials(Consumer> credentials) { - credentials.accept(this.verificationCredentials); - return this; - } - - MetadataLocationRepositoryBuilder resourceLoader(ResourceLoader resourceLoader) { - this.resourceLoader = resourceLoader; - return this; - } - - MetadataResolver metadataResolver() { - ResourceBackedMetadataResolver metadataResolver = resourceBackedMetadataResolver(); - if (!this.verificationCredentials.isEmpty()) { - SignatureTrustEngine engine = new ExplicitKeySignatureTrustEngine( - new CollectionCredentialResolver(this.verificationCredentials), - DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); - SignatureValidationFilter filter = new SignatureValidationFilter(engine); - filter.setRequireSignedRoot(true); - metadataResolver.setMetadataFilter(filter); - return initialize(metadataResolver); - } - Assert.isTrue(!this.requireVerificationCredentials, "Verification credentials are required"); - return initialize(metadataResolver); - } - - private ResourceBackedMetadataResolver resourceBackedMetadataResolver() { - Resource resource = this.resourceLoader.getResource(this.metadataLocation); - try { - return new ResourceBackedMetadataResolver(new SpringResource(resource)); - } - catch (IOException ex) { - throw new Saml2Exception(ex); - } - } - - private MetadataResolver initialize(ResourceBackedMetadataResolver metadataResolver) { - try { - metadataResolver.setId(this.getClass().getName() + ".metadataResolver"); - metadataResolver.setParserPool(XMLObjectProviderRegistrySupport.getParserPool()); - metadataResolver.setIndexes(Set.of(new RoleMetadataIndex())); - metadataResolver.initialize(); - return metadataResolver; - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } + static MetadataResolver initialize(ResourceBackedMetadataResolver metadataResolver) { + try { + metadataResolver.setId(BaseOpenSamlAssertingPartyMetadataRepository.class.getName() + ".metadataResolver"); + metadataResolver.setIndexes(Set.of(new RoleMetadataIndex())); + metadataResolver.initialize(); + return metadataResolver; } - - private static final class SpringResource implements net.shibboleth.utilities.java.support.resource.Resource { - - private final Resource resource; - - SpringResource(Resource resource) { - this.resource = resource; - } - - @Override - public boolean exists() { - return this.resource.exists(); - } - - @Override - public boolean isReadable() { - return this.resource.isReadable(); - } - - @Override - public boolean isOpen() { - return this.resource.isOpen(); - } - - @Override - public URL getURL() throws IOException { - return this.resource.getURL(); - } - - @Override - public URI getURI() throws IOException { - return this.resource.getURI(); - } - - @Override - public File getFile() throws IOException { - return this.resource.getFile(); - } - - @Nonnull - @Override - public InputStream getInputStream() throws IOException { - return this.resource.getInputStream(); - } - - @Override - public long contentLength() throws IOException { - return this.resource.contentLength(); - } - - @Override - public long lastModified() throws IOException { - return this.resource.lastModified(); - } - - @Override - public net.shibboleth.utilities.java.support.resource.Resource createRelativeResource(String relativePath) - throws IOException { - return new SpringResource(this.resource.createRelative(relativePath)); - } - - @Override - public String getFilename() { - return this.resource.getFilename(); - } - - @Override - public String getDescription() { - return this.resource.getDescription(); - } - + catch (Exception ex) { + throw new Saml2Exception(ex); } - } abstract static class MetadataResolverAdapter { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/OpenSamlMetadataUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/OpenSamlMetadataUtils.java index b7efc22c40..d6fb90cf71 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/OpenSamlMetadataUtils.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/OpenSamlMetadataUtils.java @@ -20,9 +20,11 @@ import java.io.InputStream; import java.util.Collection; import java.util.Collections; +import org.opensaml.core.Version; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; import org.opensaml.saml.saml2.metadata.EntitiesDescriptor; import org.opensaml.saml.saml2.metadata.EntityDescriptor; import org.w3c.dom.Document; @@ -33,8 +35,27 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService; final class OpenSamlMetadataUtils { + private static final OpenSamlDeserializer saml; + static { OpenSamlInitializationService.initialize(); + saml = resolveDeserializer(); + } + + static OpenSamlDeserializer resolveDeserializer() { + if (Version.getVersion().startsWith("4")) { + return new OpenSaml4Deserializer(); + } + String opensaml5 = "org.springframework.security.saml2.provider.service.registration.OpenSaml5Template"; + try { + Class template = Class.forName(opensaml5); + OpenSamlOperations operations = (OpenSamlOperations) template.getDeclaredConstructor().newInstance(); + return operations::deserialize; + } + catch (Exception ex) { + throw new IllegalStateException( + "Application appears to be using OpenSAML 5, but Spring's OpenSAML 5 support is not on the classpath"); + } } private OpenSamlMetadataUtils() { @@ -42,7 +63,7 @@ final class OpenSamlMetadataUtils { } static Collection descriptors(InputStream metadata) { - XMLObject object = xmlObject(metadata); + XMLObject object = saml.deserialize(metadata); if (object instanceof EntityDescriptor descriptor) { return Collections.singleton(descriptor); } @@ -52,28 +73,34 @@ final class OpenSamlMetadataUtils { throw new Saml2Exception("Unsupported element type: " + object.getClass().getName()); } - static XMLObject xmlObject(InputStream inputStream) { - Document document = document(inputStream); - Element element = document.getDocumentElement(); - Unmarshaller unmarshaller = XMLObjectProviderRegistrySupport.getUnmarshallerFactory().getUnmarshaller(element); - if (unmarshaller == null) { - throw new Saml2Exception("Unsupported element of type " + element.getTagName()); - } - try { - return unmarshaller.unmarshall(element); - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } + private interface OpenSamlDeserializer { + + XMLObject deserialize(InputStream serialized); + } - static Document document(InputStream inputStream) { - try { - return XMLObjectProviderRegistrySupport.getParserPool().parse(inputStream); - } - catch (Exception ex) { - throw new Saml2Exception(ex); + private static class OpenSaml4Deserializer implements OpenSamlDeserializer { + + @Override + public XMLObject deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } } + } } diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepository.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepository.java index 4d85a683b1..5614532628 100644 --- a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepository.java +++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepository.java @@ -16,23 +16,40 @@ package org.springframework.security.saml2.provider.service.registration; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URL; +import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.function.Consumer; +import javax.annotation.Nonnull; + import net.shibboleth.utilities.java.support.resolver.CriteriaSet; import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.saml.criterion.EntityRoleCriterion; import org.opensaml.saml.metadata.IterableMetadataSource; import org.opensaml.saml.metadata.resolver.MetadataResolver; +import org.opensaml.saml.metadata.resolver.filter.impl.SignatureValidationFilter; +import org.opensaml.saml.metadata.resolver.impl.ResourceBackedMetadataResolver; import org.opensaml.saml.metadata.resolver.index.impl.RoleMetadataIndex; import org.opensaml.saml.saml2.metadata.EntityDescriptor; import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.springframework.core.io.DefaultResourceLoader; +import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; -import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.provider.service.registration.BaseOpenSamlAssertingPartyMetadataRepository.MetadataResolverAdapter; import org.springframework.util.Assert; @@ -145,48 +162,136 @@ public final class OpenSaml4AssertingPartyMetadataRepository implements Assertin */ public static final class MetadataLocationRepositoryBuilder { - private final BaseOpenSamlAssertingPartyMetadataRepository.MetadataLocationRepositoryBuilder builder; + private final String metadataLocation; + + private final boolean requireVerificationCredentials; + + private final Collection verificationCredentials = new ArrayList<>(); + + private ResourceLoader resourceLoader = new DefaultResourceLoader(); MetadataLocationRepositoryBuilder(String metadataLocation, boolean trusted) { - this.builder = new BaseOpenSamlAssertingPartyMetadataRepository.MetadataLocationRepositoryBuilder( - metadataLocation, trusted); + this.metadataLocation = metadataLocation; + this.requireVerificationCredentials = !trusted; } - /** - * Apply this {@link Consumer} to the list of {@link Saml2X509Credential}s to use - * for verifying metadata signatures. - * - *

- * If no credentials are supplied, no signature verification is performed. - * @param credentials a {@link Consumer} of the {@link Collection} of - * {@link Saml2X509Credential}s - * @return the - * {@link BaseOpenSamlAssertingPartyMetadataRepository.MetadataLocationRepositoryBuilder} - * for further configuration - */ public MetadataLocationRepositoryBuilder verificationCredentials(Consumer> credentials) { - this.builder.verificationCredentials(credentials); + credentials.accept(this.verificationCredentials); return this; } - /** - * Use this {@link ResourceLoader} for resolving the {@code metadataLocation} - * @param resourceLoader the {@link ResourceLoader} to use - * @return the - * {@link BaseOpenSamlAssertingPartyMetadataRepository.MetadataLocationRepositoryBuilder} - * for further configuration - */ public MetadataLocationRepositoryBuilder resourceLoader(ResourceLoader resourceLoader) { - this.builder.resourceLoader(resourceLoader); + this.resourceLoader = resourceLoader; return this; } - /** - * Build the {@link OpenSaml4AssertingPartyMetadataRepository} - * @return the {@link OpenSaml4AssertingPartyMetadataRepository} - */ public OpenSaml4AssertingPartyMetadataRepository build() { - return new OpenSaml4AssertingPartyMetadataRepository(this.builder.metadataResolver()); + return new OpenSaml4AssertingPartyMetadataRepository(metadataResolver()); + } + + private MetadataResolver metadataResolver() { + ResourceBackedMetadataResolver metadataResolver = resourceBackedMetadataResolver(); + boolean missingCredentials = this.requireVerificationCredentials && this.verificationCredentials.isEmpty(); + Assert.isTrue(!missingCredentials, "Verification credentials are required"); + return initialize(metadataResolver); + } + + private ResourceBackedMetadataResolver resourceBackedMetadataResolver() { + Resource resource = this.resourceLoader.getResource(this.metadataLocation); + try { + ResourceBackedMetadataResolver metadataResolver = new ResourceBackedMetadataResolver( + new SpringResource(resource)); + if (this.verificationCredentials.isEmpty()) { + return metadataResolver; + } + SignatureTrustEngine engine = new ExplicitKeySignatureTrustEngine( + new CollectionCredentialResolver(this.verificationCredentials), + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + SignatureValidationFilter filter = new SignatureValidationFilter(engine); + filter.setRequireSignedRoot(true); + metadataResolver.setMetadataFilter(filter); + return metadataResolver; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private MetadataResolver initialize(ResourceBackedMetadataResolver metadataResolver) { + metadataResolver.setParserPool(XMLObjectProviderRegistrySupport.getParserPool()); + return BaseOpenSamlAssertingPartyMetadataRepository.initialize(metadataResolver); + } + + private static final class SpringResource implements net.shibboleth.utilities.java.support.resource.Resource { + + private final Resource resource; + + SpringResource(Resource resource) { + this.resource = resource; + } + + @Override + public boolean exists() { + return this.resource.exists(); + } + + @Override + public boolean isReadable() { + return this.resource.isReadable(); + } + + @Override + public boolean isOpen() { + return this.resource.isOpen(); + } + + @Override + public URL getURL() throws IOException { + return this.resource.getURL(); + } + + @Override + public URI getURI() throws IOException { + return this.resource.getURI(); + } + + @Override + public File getFile() throws IOException { + return this.resource.getFile(); + } + + @Nonnull + @Override + public InputStream getInputStream() throws IOException { + return this.resource.getInputStream(); + } + + @Override + public long contentLength() throws IOException { + return this.resource.contentLength(); + } + + @Override + public long lastModified() throws IOException { + return this.resource.lastModified(); + } + + @Override + public net.shibboleth.utilities.java.support.resource.Resource createRelativeResource(String relativePath) + throws IOException { + return new SpringResource(this.resource.createRelative(relativePath)); + } + + @Override + public String getFilename() { + return this.resource.getFilename(); + } + + @Override + public String getDescription() { + return this.resource.getDescription(); + } + } } diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index 1713de3bae..ef45985419 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -74,7 +74,7 @@ import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken; -import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject; +import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSaml4Objects.CustomOpenSamlObject; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.util.StringUtils; @@ -349,7 +349,7 @@ public class OpenSaml4AuthenticationProviderTests { Response response = response(); Assertion assertion = assertion(); AttributeStatement attribute = TestOpenSamlObjects.customAttributeStatement("Address", - TestCustomOpenSamlObjects.instance()); + TestCustomOpenSaml4Objects.instance()); assertion.getAttributeStatements().add(attribute); response.getAssertions().add(signed(assertion)); Saml2AuthenticationToken token = token(response, verifying(registration())); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObjects.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSaml4Objects.java similarity index 95% rename from saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObjects.java rename to saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSaml4Objects.java index ae6c0cc244..f5336b9be3 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSaml4Objects.java @@ -42,20 +42,20 @@ import org.w3c.dom.Element; import org.springframework.security.saml2.core.OpenSamlInitializationService; -public final class TestCustomOpenSamlObjects { +public final class TestCustomOpenSaml4Objects { static { OpenSamlInitializationService.initialize(); XMLObjectProviderRegistrySupport.getMarshallerFactory() .registerMarshaller(CustomOpenSamlObject.TYPE_NAME, - new TestCustomOpenSamlObjects.CustomSamlObjectMarshaller()); + new TestCustomOpenSaml4Objects.CustomSamlObjectMarshaller()); XMLObjectProviderRegistrySupport.getUnmarshallerFactory() .registerUnmarshaller(CustomOpenSamlObject.TYPE_NAME, - new TestCustomOpenSamlObjects.CustomSamlObjectUnmarshaller()); + new TestCustomOpenSaml4Objects.CustomSamlObjectUnmarshaller()); } public static CustomOpenSamlObject instance() { - CustomOpenSamlObject samlObject = new TestCustomOpenSamlObjects.CustomSamlObjectBuilder() + CustomOpenSamlObject samlObject = new TestCustomOpenSaml4Objects.CustomSamlObjectBuilder() .buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, CustomOpenSamlObject.TYPE_NAME); XSAny street = new XSAnyBuilder().buildObject(CustomOpenSamlObject.CUSTOM_NS, "Street", CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); @@ -76,7 +76,7 @@ public final class TestCustomOpenSamlObjects { return samlObject; } - private TestCustomOpenSamlObjects() { + private TestCustomOpenSaml4Objects() { } diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/internal/OpenSaml5Template.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/internal/OpenSaml5Template.java new file mode 100644 index 0000000000..ef74e61c97 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/internal/OpenSaml5Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.internal; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.shared.resolver.CriteriaSet; +import net.shibboleth.shared.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml5Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml5Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(Element element) { + return new OpenSaml5SerializationConfigurer(element); + } + + @Override + public OpenSaml5SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml5SignatureConfigurer(credentials); + } + + @Override + public OpenSaml5VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml5VerificationConfigurer(credentials); + } + + @Override + public OpenSaml5DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml5DecryptionConfigurer(credentials); + } + + OpenSaml5Template() { + + } + + static final class OpenSaml5SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml5SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml5SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml5SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml5SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml5SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml5VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml5VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml5DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml5DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5AuthenticationProvider.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5AuthenticationProvider.java new file mode 100644 index 0000000000..2ab8cbaaf1 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5AuthenticationProvider.java @@ -0,0 +1,496 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import org.opensaml.saml.common.assertion.ValidationContext; +import org.opensaml.saml.common.assertion.ValidationResult; +import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator; +import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.core.SubjectConfirmationData; +import org.opensaml.saml.saml2.encryption.Decrypter; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Implementation of {@link AuthenticationProvider} for SAML authentications when + * receiving a {@code Response} object containing an {@code Assertion}. This + * implementation uses the {@code OpenSAML 4} library. + * + *

+ * The {@link OpenSaml5AuthenticationProvider} supports {@link Saml2AuthenticationToken} + * objects that contain a SAML response in its decoded XML format + * {@link Saml2AuthenticationToken#getSaml2Response()} along with the information about + * the asserting party, the identity provider (IDP), as well as the relying party, the + * service provider (SP, this application). + *

+ * The {@link Saml2AuthenticationToken} will be processed into a SAML Response object. The + * SAML response object can be signed. If the Response is signed, a signature will not be + * required on the assertion. + *

+ * While a response object can contain a list of assertion, this provider will only + * leverage the first valid assertion for the purpose of authentication. Assertions that + * do not pass validation will be ignored. If no valid assertions are found a + * {@link Saml2AuthenticationException} is thrown. + *

+ * This provider supports two types of encrypted SAML elements + *

+ * If the assertion is encrypted, then signature validation on the assertion is no longer + * required. + *

+ * This provider does not perform an X509 certificate validation on the configured + * asserting party, IDP, verification certificates. + * + * @author Josh Cummings + * @since 5.5 + * @see SAML 2 + * StatusResponse + * @see OpenSAML + */ +public final class OpenSaml5AuthenticationProvider implements AuthenticationProvider { + + private final BaseOpenSamlAuthenticationProvider delegate; + + /** + * Creates an {@link OpenSaml5AuthenticationProvider} + */ + public OpenSaml5AuthenticationProvider() { + this.delegate = new BaseOpenSamlAuthenticationProvider(new OpenSaml5Template()); + setAssertionValidator(createDefaultAssertionValidator()); + } + + /** + * Set the {@link Consumer} strategy to use for decrypting elements of a validated + * {@link Response}. The default strategy decrypts all {@link EncryptedAssertion}s + * using OpenSAML's {@link Decrypter}, adding the results to + * {@link Response#getAssertions()}. + * + * You can use this method to configure the {@link Decrypter} instance like so: + * + *

+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setResponseElementsDecrypter((responseToken) -> {
+	 *	    DecrypterParameters parameters = new DecrypterParameters();
+	 *	    // ... set parameters as needed
+	 *	    Decrypter decrypter = new Decrypter(parameters);
+	 *		Response response = responseToken.getResponse();
+	 *  	EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
+	 *  	try {
+	 *  		Assertion assertion = decrypter.decrypt(encrypted);
+	 *  		response.getAssertions().add(assertion);
+	 *  	} catch (Exception e) {
+	 *  	 	throw new Saml2AuthenticationException(...);
+	 *  	}
+	 *	});
+	 * 
+ * + * Or, in the event that you have your own custom decryption interface, the same + * pattern applies: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	Converter<EncryptedAssertion, Assertion> myService = ...
+	 *	provider.setResponseDecrypter((responseToken) -> {
+	 *	   Response response = responseToken.getResponse();
+	 *	   response.getEncryptedAssertions().stream()
+	 *	   		.map(service::decrypt).forEach(response.getAssertions()::add);
+	 *	});
+	 * 
+ * + * This is valuable when using an external service to perform the decryption. + * @param responseElementsDecrypter the {@link Consumer} for decrypting response + * elements + * @since 5.5 + */ + public void setResponseElementsDecrypter(Consumer responseElementsDecrypter) { + Assert.notNull(responseElementsDecrypter, "responseElementsDecrypter cannot be null"); + this.delegate + .setResponseElementsDecrypter((token) -> responseElementsDecrypter.accept(new ResponseToken(token))); + } + + /** + * Set the {@link Converter} to use for validating the SAML 2.0 Response. + * + * You can still invoke the default validator by delegating to + * {@link #createDefaultResponseValidator()}, like so: + * + *
+	 * OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider();
+	 * provider.setResponseValidator(responseToken -> {
+	 * 		Saml2ResponseValidatorResult result = createDefaultResponseValidator()
+	 * 			.convert(responseToken)
+	 * 		return result.concat(myCustomValidator.convert(responseToken));
+	 * });
+	 * 
+ * @param responseValidator the {@link Converter} to use + * @since 5.6 + */ + public void setResponseValidator(Converter responseValidator) { + Assert.notNull(responseValidator, "responseValidator cannot be null"); + this.delegate.setResponseValidator((token) -> responseValidator.convert(new ResponseToken(token))); + } + + /** + * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML + * 2.0 Response. + * + * You can still invoke the default validator by delgating to + * {@link #createAssertionValidator}, like so: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *  provider.setAssertionValidator(assertionToken -> {
+	 *		Saml2ResponseValidatorResult result = createDefaultAssertionValidator()
+	 *			.convert(assertionToken)
+	 *		return result.concat(myCustomValidator.convert(assertionToken));
+	 *  });
+	 * 
+ * + * You can also use this method to configure the provider to use a different + * {@link ValidationContext} from the default, like so: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setAssertionValidator(
+	 *		createDefaultAssertionValidator(assertionToken -> {
+	 *			Map<String, Object> params = new HashMap<>();
+	 *			params.put(CLOCK_SKEW, 2 * 60 * 1000);
+	 *			// other parameters
+	 *			return new ValidationContext(params);
+	 *		}));
+	 * 
+ * + * Consider taking a look at {@link #createValidationContext} to see how it constructs + * a {@link ValidationContext}. + * + * It is not necessary to delegate to the default validator. You can safely replace it + * entirely with your own. Note that signature verification is performed as a separate + * step from this validator. + * @param assertionValidator the validator to use + * @since 5.4 + */ + public void setAssertionValidator(Converter assertionValidator) { + Assert.notNull(assertionValidator, "assertionValidator cannot be null"); + this.delegate.setAssertionValidator((token) -> assertionValidator.convert(new AssertionToken(token))); + } + + /** + * Set the {@link Consumer} strategy to use for decrypting elements of a validated + * {@link Assertion}. + * + * You can use this method to configure the {@link Decrypter} used like so: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setResponseDecrypter((assertionToken) -> {
+	 *	    DecrypterParameters parameters = new DecrypterParameters();
+	 *	    // ... set parameters as needed
+	 *	    Decrypter decrypter = new Decrypter(parameters);
+	 *		Assertion assertion = assertionToken.getAssertion();
+	 *  	EncryptedID encrypted = assertion.getSubject().getEncryptedID();
+	 *  	try {
+	 *  		NameID name = decrypter.decrypt(encrypted);
+	 *  		assertion.getSubject().setNameID(name);
+	 *  	} catch (Exception e) {
+	 *  	 	throw new Saml2AuthenticationException(...);
+	 *  	}
+	 *	});
+	 * 
+ * + * Or, in the event that you have your own custom interface, the same pattern applies: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	MyDecryptionService myService = ...
+	 *	provider.setResponseDecrypter((responseToken) -> {
+	 *	   	Assertion assertion = assertionToken.getAssertion();
+	 *	   	EncryptedID encrypted = assertion.getSubject().getEncryptedID();
+	 *		NameID name = myService.decrypt(encrypted);
+	 *		assertion.getSubject().setNameID(name);
+	 *	});
+	 * 
+ * @param assertionDecrypter the {@link Consumer} for decrypting assertion elements + * @since 5.5 + */ + public void setAssertionElementsDecrypter(Consumer assertionDecrypter) { + Assert.notNull(assertionDecrypter, "assertionDecrypter cannot be null"); + this.delegate.setAssertionElementsDecrypter((token) -> assertionDecrypter.accept(new AssertionToken(token))); + } + + /** + * Set the {@link Converter} to use for converting a validated {@link Response} into + * an {@link AbstractAuthenticationToken}. + * + * You can delegate to the default behavior by calling + * {@link #createDefaultResponseAuthenticationConverter()} like so: + * + *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 * 	Converter<ResponseToken, Saml2Authentication> authenticationConverter =
+	 * 			createDefaultResponseAuthenticationConverter();
+	 *	provider.setResponseAuthenticationConverter(responseToken -> {
+	 *		Saml2Authentication authentication = authenticationConverter.convert(responseToken);
+	 *		User user = myUserRepository.findByUsername(authentication.getName());
+	 *		return new MyAuthentication(authentication, user);
+	 *	});
+	 * 
+ * @param responseAuthenticationConverter the {@link Converter} to use + * @since 5.4 + */ + public void setResponseAuthenticationConverter( + Converter responseAuthenticationConverter) { + Assert.notNull(responseAuthenticationConverter, "responseAuthenticationConverter cannot be null"); + this.delegate.setResponseAuthenticationConverter( + (token) -> responseAuthenticationConverter.convert(new ResponseToken(token))); + } + + /** + * Construct a default strategy for validating the SAML 2.0 Response + * @return the default response validator strategy + * @since 5.6 + */ + public static Converter createDefaultResponseValidator() { + Converter delegate = BaseOpenSamlAuthenticationProvider + .createDefaultResponseValidator(); + return (token) -> delegate + .convert(new BaseOpenSamlAuthenticationProvider.ResponseToken(token.getResponse(), token.getToken())); + } + + /** + * Construct a default strategy for validating each SAML 2.0 Assertion and associated + * {@link Authentication} token + * @return the default assertion validator strategy + */ + public static Converter createDefaultAssertionValidator() { + return createDefaultAssertionValidatorWithParameters( + (params) -> params.put(SAML2AssertionValidationParameters.CLOCK_SKEW, Duration.ofMinutes(5))); + } + + /** + * Construct a default strategy for validating each SAML 2.0 Assertion and associated + * {@link Authentication} token + * @param contextConverter the conversion strategy to use to generate a + * {@link ValidationContext} for each assertion being validated + * @return the default assertion validator strategy + * @deprecated Use {@link #createDefaultAssertionValidatorWithParameters} instead + */ + @Deprecated + public static Converter createDefaultAssertionValidator( + Converter contextConverter) { + return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, + (assertionToken) -> BaseOpenSamlAuthenticationProvider.SAML20AssertionValidators.attributeValidator, + contextConverter); + } + + /** + * Construct a default strategy for validating each SAML 2.0 Assertion and associated + * {@link Authentication} token + * @param validationContextParameters a consumer for editing the values passed to the + * {@link ValidationContext} for each assertion being validated + * @return the default assertion validator strategy + * @since 5.8 + */ + public static Converter createDefaultAssertionValidatorWithParameters( + Consumer> validationContextParameters) { + return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION, + (assertionToken) -> BaseOpenSamlAuthenticationProvider.SAML20AssertionValidators.attributeValidator, + (assertionToken) -> createValidationContext(assertionToken, validationContextParameters)); + } + + /** + * Construct a default strategy for converting a SAML 2.0 Response and + * {@link Authentication} token into a {@link Saml2Authentication} + * @return the default response authentication converter strategy + */ + public static Converter createDefaultResponseAuthenticationConverter() { + Converter delegate = BaseOpenSamlAuthenticationProvider + .createDefaultResponseAuthenticationConverter(); + return (token) -> delegate + .convert(new BaseOpenSamlAuthenticationProvider.ResponseToken(token.getResponse(), token.getToken())); + } + + /** + * @param authentication the authentication request object, must be of type + * {@link Saml2AuthenticationToken} + * @return {@link Saml2Authentication} if the assertion is valid + * @throws AuthenticationException if a validation exception occurs + */ + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + return this.delegate.authenticate(authentication); + } + + @Override + public boolean supports(Class authentication) { + return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication); + } + + private static Converter createAssertionValidator(String errorCode, + Converter validatorConverter, + Converter contextConverter) { + + return (assertionToken) -> { + Assertion assertion = assertionToken.getAssertion(); + SAML20AssertionValidator validator = validatorConverter.convert(assertionToken); + ValidationContext context = contextConverter.convert(assertionToken); + try { + ValidationResult result = validator.validate(assertion, context); + if (result == ValidationResult.VALID) { + return Saml2ResponseValidatorResult.success(); + } + } + catch (Exception ex) { + String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(), + ((Response) assertion.getParent()).getID(), ex.getMessage()); + return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message)); + } + String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(), + ((Response) assertion.getParent()).getID(), context.getValidationFailureMessages()); + return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message)); + }; + } + + private static ValidationContext createValidationContext(AssertionToken assertionToken, + Consumer> paramsConsumer) { + Saml2AuthenticationToken token = assertionToken.getToken(); + RelyingPartyRegistration relyingPartyRegistration = token.getRelyingPartyRegistration(); + String audience = relyingPartyRegistration.getEntityId(); + String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation(); + String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyMetadata().getEntityId(); + Map params = new HashMap<>(); + Assertion assertion = assertionToken.getAssertion(); + if (assertionContainsInResponseTo(assertion)) { + String requestId = getAuthnRequestId(token.getAuthenticationRequest()); + params.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId); + } + params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience)); + params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient)); + params.put(SAML2AssertionValidationParameters.VALID_ISSUERS, Collections.singleton(assertingPartyEntityId)); + paramsConsumer.accept(params); + return new ValidationContext(params); + } + + private static boolean assertionContainsInResponseTo(Assertion assertion) { + if (assertion.getSubject() == null) { + return false; + } + for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) { + SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData(); + if (confirmationData == null) { + continue; + } + if (StringUtils.hasText(confirmationData.getInResponseTo())) { + return true; + } + } + return false; + } + + private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) { + return (serialized != null) ? serialized.getId() : null; + } + + /** + * A tuple containing an OpenSAML {@link Response} and its associated authentication + * token. + * + * @since 5.4 + */ + public static class ResponseToken { + + private final Saml2AuthenticationToken token; + + private final Response response; + + ResponseToken(Response response, Saml2AuthenticationToken token) { + this.token = token; + this.response = response; + } + + ResponseToken(BaseOpenSamlAuthenticationProvider.ResponseToken token) { + this.token = token.getToken(); + this.response = token.getResponse(); + } + + public Response getResponse() { + return this.response; + } + + public Saml2AuthenticationToken getToken() { + return this.token; + } + + } + + /** + * A tuple containing an OpenSAML {@link Assertion} and its associated authentication + * token. + * + * @since 5.4 + */ + public static class AssertionToken { + + private final Saml2AuthenticationToken token; + + private final Assertion assertion; + + AssertionToken(Assertion assertion, Saml2AuthenticationToken token) { + this.token = token; + this.assertion = assertion; + } + + AssertionToken(BaseOpenSamlAuthenticationProvider.AssertionToken token) { + this.token = token.getToken(); + this.assertion = token.getAssertion(); + } + + public Assertion getAssertion() { + return this.assertion; + } + + public Saml2AuthenticationToken getToken() { + return this.token; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5Template.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5Template.java new file mode 100644 index 0000000000..b2003af461 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.shared.resolver.CriteriaSet; +import net.shibboleth.shared.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml5Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml5Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(Element element) { + return new OpenSaml5SerializationConfigurer(element); + } + + @Override + public OpenSaml5SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml5SignatureConfigurer(credentials); + } + + @Override + public OpenSaml5VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml5VerificationConfigurer(credentials); + } + + @Override + public OpenSaml5DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml5DecryptionConfigurer(credentials); + } + + OpenSaml5Template() { + + } + + static final class OpenSaml5SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml5SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml5SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml5SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml5SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml5SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml5VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml5VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml5DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml5DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutRequestValidator.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutRequestValidator.java new file mode 100644 index 0000000000..a9f13cd5b0 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutRequestValidator.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication.logout; + +/** + * An OpenSAML 4.x compatible implementation of {@link Saml2LogoutResponseValidator} + * + * @author Josh Cummings + * @since 5.6 + */ +public final class OpenSaml5LogoutRequestValidator implements Saml2LogoutRequestValidator { + + @SuppressWarnings("deprecation") + private final Saml2LogoutRequestValidator delegate = new BaseOpenSamlLogoutRequestValidator( + new OpenSaml5Template()); + + @Override + public Saml2LogoutValidatorResult validate(Saml2LogoutRequestValidatorParameters parameters) { + return this.delegate.validate(parameters); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutResponseValidator.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutResponseValidator.java new file mode 100644 index 0000000000..8f99e6a5a8 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutResponseValidator.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication.logout; + +/** + * An OpenSAML 5.x compatible implementation of {@link Saml2LogoutResponseValidator} + * + * @author Josh Cummings + * @since 5.6 + */ +public final class OpenSaml5LogoutResponseValidator implements Saml2LogoutResponseValidator { + + @SuppressWarnings("deprecation") + private final Saml2LogoutResponseValidator delegate = new BaseOpenSamlLogoutResponseValidator( + new OpenSaml5Template()); + + @Override + public Saml2LogoutValidatorResult validate(Saml2LogoutResponseValidatorParameters parameters) { + return this.delegate.validate(parameters); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5Template.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5Template.java new file mode 100644 index 0000000000..d0add4c54c --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication.logout; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.shared.resolver.CriteriaSet; +import net.shibboleth.shared.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml5Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml5Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(Element element) { + return new OpenSaml5SerializationConfigurer(element); + } + + @Override + public OpenSaml5SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml5SignatureConfigurer(credentials); + } + + @Override + public OpenSaml5VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml5VerificationConfigurer(credentials); + } + + @Override + public OpenSaml5DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml5DecryptionConfigurer(credentials); + } + + OpenSaml5Template() { + + } + + static final class OpenSaml5SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml5SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml5SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml5SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml5SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml5SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml5VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml5VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml5DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml5DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/Saml2Utils.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/Saml2Utils.java new file mode 100644 index 0000000000..01c0ff5ec0 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/authentication/logout/Saml2Utils.java @@ -0,0 +1,196 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication.logout; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Base64; +import java.util.zip.Deflater; +import java.util.zip.DeflaterOutputStream; +import java.util.zip.Inflater; +import java.util.zip.InflaterOutputStream; + +import org.springframework.security.saml2.Saml2Exception; + +/** + * Utility methods for working with serialized SAML messages. + * + * For internal use only. + * + * @author Josh Cummings + */ +final class Saml2Utils { + + private Saml2Utils() { + } + + static String samlEncode(byte[] b) { + return Base64.getEncoder().encodeToString(b); + } + + static byte[] samlDecode(String s) { + return Base64.getMimeDecoder().decode(s); + } + + static byte[] samlDeflate(String s) { + try { + ByteArrayOutputStream b = new ByteArrayOutputStream(); + DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(Deflater.DEFLATED, true)); + deflater.write(s.getBytes(StandardCharsets.UTF_8)); + deflater.finish(); + return b.toByteArray(); + } + catch (IOException ex) { + throw new Saml2Exception("Unable to deflate string", ex); + } + } + + static String samlInflate(byte[] b) { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); + iout.write(b); + iout.finish(); + return new String(out.toByteArray(), StandardCharsets.UTF_8); + } + catch (IOException ex) { + throw new Saml2Exception("Unable to inflate string", ex); + } + } + + static EncodingConfigurer withDecoded(String decoded) { + return new EncodingConfigurer(decoded); + } + + static DecodingConfigurer withEncoded(String encoded) { + return new DecodingConfigurer(encoded); + } + + static final class EncodingConfigurer { + + private final String decoded; + + private boolean deflate; + + private EncodingConfigurer(String decoded) { + this.decoded = decoded; + } + + EncodingConfigurer deflate(boolean deflate) { + this.deflate = deflate; + return this; + } + + String encode() { + byte[] bytes = (this.deflate) ? Saml2Utils.samlDeflate(this.decoded) + : this.decoded.getBytes(StandardCharsets.UTF_8); + return Saml2Utils.samlEncode(bytes); + } + + } + + static final class DecodingConfigurer { + + private static final Base64Checker BASE_64_CHECKER = new Base64Checker(); + + private final String encoded; + + private boolean inflate; + + private boolean requireBase64; + + private DecodingConfigurer(String encoded) { + this.encoded = encoded; + } + + DecodingConfigurer inflate(boolean inflate) { + this.inflate = inflate; + return this; + } + + DecodingConfigurer requireBase64(boolean requireBase64) { + this.requireBase64 = requireBase64; + return this; + } + + String decode() { + if (this.requireBase64) { + BASE_64_CHECKER.checkAcceptable(this.encoded); + } + byte[] bytes = Saml2Utils.samlDecode(this.encoded); + return (this.inflate) ? Saml2Utils.samlInflate(bytes) : new String(bytes, StandardCharsets.UTF_8); + } + + static class Base64Checker { + + private static final int[] values = genValueMapping(); + + Base64Checker() { + + } + + private static int[] genValueMapping() { + byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + .getBytes(StandardCharsets.ISO_8859_1); + + int[] values = new int[256]; + Arrays.fill(values, -1); + for (int i = 0; i < alphabet.length; i++) { + values[alphabet[i] & 0xff] = i; + } + return values; + } + + boolean isAcceptable(String s) { + int goodChars = 0; + int lastGoodCharVal = -1; + + // count number of characters from Base64 alphabet + for (int i = 0; i < s.length(); i++) { + int val = values[0xff & s.charAt(i)]; + if (val != -1) { + lastGoodCharVal = val; + goodChars++; + } + } + + // in cases of an incomplete final chunk, ensure the unused bits are zero + switch (goodChars % 4) { + case 0: + return true; + case 2: + return (lastGoodCharVal & 0b1111) == 0; + case 3: + return (lastGoodCharVal & 0b11) == 0; + default: + return false; + } + } + + void checkAcceptable(String ins) { + if (!isAcceptable(ins)) { + throw new IllegalArgumentException("Failed to decode SAMLResponse"); + } + } + + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5MetadataResolver.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5MetadataResolver.java new file mode 100644 index 0000000000..d03e6cda94 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5MetadataResolver.java @@ -0,0 +1,117 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.metadata; + +import java.util.function.Consumer; + +import org.opensaml.saml.saml2.metadata.EntityDescriptor; + +import org.springframework.security.saml2.core.OpenSamlInitializationService; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; + +/** + * Resolves the SAML 2.0 Relying Party Metadata for a given + * {@link RelyingPartyRegistration} using the OpenSAML API. + * + * @author Jakub Kubrynski + * @author Josh Cummings + * @since 5.4 + */ +public final class OpenSaml5MetadataResolver implements Saml2MetadataResolver { + + static { + OpenSamlInitializationService.initialize(); + } + + private final BaseOpenSamlMetadataResolver delegate; + + public OpenSaml5MetadataResolver() { + this.delegate = new BaseOpenSamlMetadataResolver(new OpenSaml5Template()); + } + + @Override + public String resolve(RelyingPartyRegistration relyingPartyRegistration) { + return this.delegate.resolve(relyingPartyRegistration); + } + + public String resolve(Iterable relyingPartyRegistrations) { + return this.delegate.resolve(relyingPartyRegistrations); + } + + /** + * Set a {@link Consumer} for modifying the OpenSAML {@link EntityDescriptor} + * @param entityDescriptorCustomizer a consumer that accepts an + * {@link EntityDescriptorParameters} + * @since 5.7 + */ + public void setEntityDescriptorCustomizer(Consumer entityDescriptorCustomizer) { + this.delegate.setEntityDescriptorCustomizer( + (parameters) -> entityDescriptorCustomizer.accept(new EntityDescriptorParameters(parameters))); + } + + /** + * Configure whether to pretty-print the metadata XML. This can be helpful when + * signing the metadata payload. + * + * @since 6.2 + **/ + public void setUsePrettyPrint(boolean usePrettyPrint) { + this.delegate.setUsePrettyPrint(usePrettyPrint); + } + + /** + * Configure whether to sign the metadata, defaults to {@code false}. + * + * @since 6.4 + */ + public void setSignMetadata(boolean signMetadata) { + this.delegate.setSignMetadata(signMetadata); + } + + /** + * A tuple containing an OpenSAML {@link EntityDescriptor} and its associated + * {@link RelyingPartyRegistration} + * + * @since 5.7 + */ + public static final class EntityDescriptorParameters { + + private final EntityDescriptor entityDescriptor; + + private final RelyingPartyRegistration registration; + + public EntityDescriptorParameters(EntityDescriptor entityDescriptor, RelyingPartyRegistration registration) { + this.entityDescriptor = entityDescriptor; + this.registration = registration; + } + + EntityDescriptorParameters(BaseOpenSamlMetadataResolver.EntityDescriptorParameters parameters) { + this.entityDescriptor = parameters.getEntityDescriptor(); + this.registration = parameters.getRelyingPartyRegistration(); + } + + public EntityDescriptor getEntityDescriptor() { + return this.entityDescriptor; + } + + public RelyingPartyRegistration getRelyingPartyRegistration() { + return this.registration; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5Template.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5Template.java new file mode 100644 index 0000000000..e5a39122a3 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.metadata; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.shared.resolver.CriteriaSet; +import net.shibboleth.shared.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml5Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml5Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(Element element) { + return new OpenSaml5SerializationConfigurer(element); + } + + @Override + public OpenSaml5SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml5SignatureConfigurer(credentials); + } + + @Override + public OpenSaml5VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml5VerificationConfigurer(credentials); + } + + @Override + public OpenSaml5DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml5DecryptionConfigurer(credentials); + } + + OpenSaml5Template() { + + } + + static final class OpenSaml5SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml5SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml5SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml5SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml5SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml5SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml5VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml5VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml5DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml5DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepository.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepository.java new file mode 100644 index 0000000000..381bb599b9 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepository.java @@ -0,0 +1,318 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URL; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.function.Consumer; + +import javax.annotation.Nonnull; + +import net.shibboleth.shared.resolver.CriteriaSet; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.saml.criterion.EntityRoleCriterion; +import org.opensaml.saml.metadata.IterableMetadataSource; +import org.opensaml.saml.metadata.resolver.MetadataResolver; +import org.opensaml.saml.metadata.resolver.filter.impl.SignatureValidationFilter; +import org.opensaml.saml.metadata.resolver.impl.ResourceBackedMetadataResolver; +import org.opensaml.saml.metadata.resolver.index.impl.RoleMetadataIndex; +import org.opensaml.saml.saml2.metadata.EntityDescriptor; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; + +import org.springframework.core.io.DefaultResourceLoader; +import org.springframework.core.io.Resource; +import org.springframework.core.io.ResourceLoader; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.provider.service.registration.BaseOpenSamlAssertingPartyMetadataRepository.MetadataResolverAdapter; +import org.springframework.util.Assert; + +/** + * An implementation of {@link AssertingPartyMetadataRepository} that uses a + * {@link MetadataResolver} to retrieve {@link AssertingPartyMetadata} instances. + * + *

+ * The {@link MetadataResolver} constructed in {@link #withTrustedMetadataLocation} + * provides expiry-aware refreshing. + * + * @author Josh Cummings + * @since 6.4 + * @see AssertingPartyMetadataRepository + * @see RelyingPartyRegistrations + */ +public final class OpenSaml5AssertingPartyMetadataRepository implements AssertingPartyMetadataRepository { + + private final BaseOpenSamlAssertingPartyMetadataRepository delegate; + + /** + * Construct an {@link OpenSaml5AssertingPartyMetadataRepository} using the provided + * {@link MetadataResolver}. + * + *

+ * The {@link MetadataResolver} should either be of type + * {@link IterableMetadataSource} or it should have a {@link RoleMetadataIndex} + * configured. + * @param metadataResolver the {@link MetadataResolver} to use + */ + public OpenSaml5AssertingPartyMetadataRepository(MetadataResolver metadataResolver) { + Assert.notNull(metadataResolver, "metadataResolver cannot be null"); + this.delegate = new BaseOpenSamlAssertingPartyMetadataRepository( + new CriteriaSetResolverWrapper(metadataResolver)); + } + + /** + * {@inheritDoc} + */ + @Override + @NonNull + public Iterator iterator() { + return this.delegate.iterator(); + } + + /** + * {@inheritDoc} + */ + @Nullable + @Override + public AssertingPartyMetadata findByEntityId(String entityId) { + return this.delegate.findByEntityId(entityId); + } + + /** + * Use this trusted {@code metadataLocation} to retrieve refreshable, expiry-aware + * SAML 2.0 Asserting Party (IDP) metadata. + * + *

+ * Valid locations can be classpath- or file-based or they can be HTTPS endpoints. + * Some valid endpoints might include: + * + *

+	 *   metadataLocation = "classpath:asserting-party-metadata.xml";
+	 *   metadataLocation = "file:asserting-party-metadata.xml";
+	 *   metadataLocation = "https://ap.example.org/metadata";
+	 * 
+ * + *

+ * Resolution of location is attempted immediately. To defer, wrap in + * {@link CachingRelyingPartyRegistrationRepository}. + * @param metadataLocation the classpath- or file-based locations or HTTPS endpoints + * of the asserting party metadata file + * @return the {@link MetadataLocationRepositoryBuilder} for further configuration + */ + public static MetadataLocationRepositoryBuilder withTrustedMetadataLocation(String metadataLocation) { + return new MetadataLocationRepositoryBuilder(metadataLocation, true); + } + + /** + * Use this {@code metadataLocation} to retrieve refreshable, expiry-aware SAML 2.0 + * Asserting Party (IDP) metadata. Verification credentials are required. + * + *

+ * Valid locations can be classpath- or file-based or they can be remote endpoints. + * Some valid endpoints might include: + * + *

+	 *   metadataLocation = "classpath:asserting-party-metadata.xml";
+	 *   metadataLocation = "file:asserting-party-metadata.xml";
+	 *   metadataLocation = "https://ap.example.org/metadata";
+	 * 
+ * + *

+ * Resolution of location is attempted immediately. To defer, wrap in + * {@link CachingRelyingPartyRegistrationRepository}. + * @param metadataLocation the classpath- or file-based locations or remote endpoints + * of the asserting party metadata file + * @return the {@link MetadataLocationRepositoryBuilder} for further configuration + */ + public static MetadataLocationRepositoryBuilder withMetadataLocation(String metadataLocation) { + return new MetadataLocationRepositoryBuilder(metadataLocation, false); + } + + /** + * A builder class for configuring {@link OpenSaml5AssertingPartyMetadataRepository} + * for a specific metadata location. + * + * @author Josh Cummings + */ + public static final class MetadataLocationRepositoryBuilder { + + private final String metadataLocation; + + private final boolean requireVerificationCredentials; + + private final Collection verificationCredentials = new ArrayList<>(); + + private ResourceLoader resourceLoader = new DefaultResourceLoader(); + + MetadataLocationRepositoryBuilder(String metadataLocation, boolean trusted) { + this.metadataLocation = metadataLocation; + this.requireVerificationCredentials = !trusted; + } + + public MetadataLocationRepositoryBuilder verificationCredentials(Consumer> credentials) { + credentials.accept(this.verificationCredentials); + return this; + } + + public MetadataLocationRepositoryBuilder resourceLoader(ResourceLoader resourceLoader) { + this.resourceLoader = resourceLoader; + return this; + } + + public OpenSaml5AssertingPartyMetadataRepository build() { + return new OpenSaml5AssertingPartyMetadataRepository(metadataResolver()); + } + + private MetadataResolver metadataResolver() { + ResourceBackedMetadataResolver metadataResolver = resourceBackedMetadataResolver(); + boolean missingCredentials = this.requireVerificationCredentials && this.verificationCredentials.isEmpty(); + Assert.isTrue(!missingCredentials, "Verification credentials are required"); + return initialize(metadataResolver); + } + + private ResourceBackedMetadataResolver resourceBackedMetadataResolver() { + Resource resource = this.resourceLoader.getResource(this.metadataLocation); + try { + ResourceBackedMetadataResolver metadataResolver = new ResourceBackedMetadataResolver( + new SpringResource(resource)); + if (this.verificationCredentials.isEmpty()) { + return metadataResolver; + } + SignatureTrustEngine engine = new ExplicitKeySignatureTrustEngine( + new CollectionCredentialResolver(this.verificationCredentials), + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + SignatureValidationFilter filter = new SignatureValidationFilter(engine); + filter.setRequireSignedRoot(true); + metadataResolver.setMetadataFilter(filter); + filter.initialize(); + return metadataResolver; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private MetadataResolver initialize(ResourceBackedMetadataResolver metadataResolver) { + metadataResolver.setParserPool(XMLObjectProviderRegistrySupport.getParserPool()); + return BaseOpenSamlAssertingPartyMetadataRepository.initialize(metadataResolver); + } + + private static final class SpringResource implements net.shibboleth.shared.resource.Resource { + + private final Resource resource; + + SpringResource(Resource resource) { + this.resource = resource; + } + + @Override + public boolean exists() { + return this.resource.exists(); + } + + @Override + public boolean isReadable() { + return this.resource.isReadable(); + } + + @Override + public boolean isOpen() { + return this.resource.isOpen(); + } + + @Override + public URL getURL() throws IOException { + return this.resource.getURL(); + } + + @Override + public URI getURI() throws IOException { + return this.resource.getURI(); + } + + @Override + public File getFile() throws IOException { + return this.resource.getFile(); + } + + @Nonnull + @Override + public InputStream getInputStream() throws IOException { + return this.resource.getInputStream(); + } + + @Override + public long contentLength() throws IOException { + return this.resource.contentLength(); + } + + @Override + public long lastModified() throws IOException { + return this.resource.lastModified(); + } + + @Override + public net.shibboleth.shared.resource.Resource createRelativeResource(String relativePath) + throws IOException { + return new SpringResource(this.resource.createRelative(relativePath)); + } + + @Override + public String getFilename() { + return this.resource.getFilename(); + } + + @Override + public String getDescription() { + return this.resource.getDescription(); + } + + } + + } + + private static final class CriteriaSetResolverWrapper extends MetadataResolverAdapter { + + CriteriaSetResolverWrapper(MetadataResolver metadataResolver) { + super(metadataResolver); + } + + @Override + EntityDescriptor resolveSingle(EntityIdCriterion entityId) throws Exception { + return super.metadataResolver.resolveSingle(new CriteriaSet(entityId)); + } + + @Override + Iterable resolve(EntityRoleCriterion role) throws Exception { + return super.metadataResolver.resolve(new CriteriaSet(role)); + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5Template.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5Template.java new file mode 100644 index 0000000000..50f60e5a0d --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.shared.resolver.CriteriaSet; +import net.shibboleth.shared.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml5Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml5Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(Element element) { + return new OpenSaml5SerializationConfigurer(element); + } + + @Override + public OpenSaml5SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml5SignatureConfigurer(credentials); + } + + @Override + public OpenSaml5VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml5VerificationConfigurer(credentials); + } + + @Override + public OpenSaml5DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml5DecryptionConfigurer(credentials); + } + + OpenSaml5Template() { + + } + + static final class OpenSaml5SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml5SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml5SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml5SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml5SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml5SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml5VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml5VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml5DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml5DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverter.java new file mode 100644 index 0000000000..48a2bd35a2 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverter.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; + +/** + * An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken} + * appropriate for authenticated a SAML 2.0 Assertion against an + * {@link org.springframework.security.authentication.AuthenticationManager}. + * + * @author Josh Cummings + * @since 6.1 + */ +public final class OpenSaml5AuthenticationTokenConverter implements AuthenticationConverter { + + private final BaseOpenSamlAuthenticationTokenConverter delegate; + + /** + * Constructs a {@link OpenSaml5AuthenticationTokenConverter} given a repository for + * {@link RelyingPartyRegistration}s + * @param registrations the repository for {@link RelyingPartyRegistration}s + * {@link RelyingPartyRegistration}s + */ + public OpenSaml5AuthenticationTokenConverter(RelyingPartyRegistrationRepository registrations) { + Assert.notNull(registrations, "relyingPartyRegistrationRepository cannot be null"); + this.delegate = new BaseOpenSamlAuthenticationTokenConverter(registrations, new OpenSaml5Template()); + } + + /** + * Resolve an authentication request from the given {@link HttpServletRequest}. + * + *

+ * First uses the configured {@link RequestMatcher} to deduce whether an + * authentication request is being made and optionally for which + * {@code registrationId}. + * + *

+ * If there is an associated {@code }, then the + * {@code registrationId} is looked up and used. + * + *

+ * If a {@code registrationId} is found in the request, then it is looked up and used. + * In that case, if none is found a {@link Saml2AuthenticationException} is thrown. + * + *

+ * Finally, if no {@code registrationId} is found in the request, then the code + * attempts to resolve the {@link RelyingPartyRegistration} from the SAML Response's + * Issuer. + * @param request the HTTP request + * @return the {@link Saml2AuthenticationToken} authentication request + * @throws Saml2AuthenticationException if the {@link RequestMatcher} specifies a + * non-existent {@code registrationId} + */ + @Override + public Saml2AuthenticationToken convert(HttpServletRequest request) { + return this.delegate.convert(request); + } + + /** + * Use the given {@link Saml2AuthenticationRequestRepository} to load authentication + * request. + * @param authenticationRequestRepository the + * {@link Saml2AuthenticationRequestRepository} to use + */ + public void setAuthenticationRequestRepository( + Saml2AuthenticationRequestRepository authenticationRequestRepository) { + Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null"); + this.delegate.setAuthenticationRequestRepository(authenticationRequestRepository); + } + + /** + * Use the given {@link RequestMatcher} to match the request. + * @param requestMatcher the {@link RequestMatcher} to use + */ + public void setRequestMatcher(RequestMatcher requestMatcher) { + Assert.notNull(requestMatcher, "requestMatcher cannot be null"); + this.delegate.setRequestMatcher(requestMatcher); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/OpenSaml5Template.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/OpenSaml5Template.java new file mode 100644 index 0000000000..73b3d9f391 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/OpenSaml5Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.shared.resolver.CriteriaSet; +import net.shibboleth.shared.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml5Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml5Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(Element element) { + return new OpenSaml5SerializationConfigurer(element); + } + + @Override + public OpenSaml5SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml5SignatureConfigurer(credentials); + } + + @Override + public OpenSaml5VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml5VerificationConfigurer(credentials); + } + + @Override + public OpenSaml5DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml5DecryptionConfigurer(credentials); + } + + OpenSaml5Template() { + + } + + static final class OpenSaml5SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml5SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml5SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml5SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml5SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml5SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml5VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml5VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml5DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml5DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolver.java new file mode 100644 index 0000000000..7b921d1b45 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolver.java @@ -0,0 +1,144 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.time.Clock; +import java.time.Instant; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletRequest; +import org.opensaml.saml.saml2.core.AuthnRequest; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; + +/** + * A strategy for resolving a SAML 2.0 Authentication Request from the + * {@link HttpServletRequest} using OpenSAML. + * + * @author Josh Cummings + * @since 5.7 + */ +public final class OpenSaml5AuthenticationRequestResolver implements Saml2AuthenticationRequestResolver { + + private final BaseOpenSamlAuthenticationRequestResolver delegate; + + /** + * Construct an {@link OpenSaml5AuthenticationRequestResolver} + * @param registrations a repository for relying and asserting party configuration + * @since 6.1 + */ + public OpenSaml5AuthenticationRequestResolver(RelyingPartyRegistrationRepository registrations) { + this.delegate = new BaseOpenSamlAuthenticationRequestResolver((request, id) -> { + if (id == null) { + return null; + } + return registrations.findByRegistrationId(id); + }, new OpenSaml5Template()); + } + + /** + * Construct a {@link OpenSaml5AuthenticationRequestResolver} + */ + public OpenSaml5AuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + this.delegate = new BaseOpenSamlAuthenticationRequestResolver(relyingPartyRegistrationResolver, + new OpenSaml5Template()); + } + + @Override + public T resolve(HttpServletRequest request) { + return this.delegate.resolve(request); + } + + /** + * Set a {@link Consumer} for modifying the OpenSAML {@link AuthnRequest} + * @param contextConsumer a consumer that accepts an {@link AuthnRequestContext} + */ + public void setAuthnRequestCustomizer(Consumer contextConsumer) { + Assert.notNull(contextConsumer, "contextConsumer cannot be null"); + this.delegate.setParametersConsumer( + (parameters) -> contextConsumer.accept(new AuthnRequestContext(parameters.getRequest(), + parameters.getRelyingPartyRegistration(), parameters.getAuthnRequest()))); + } + + /** + * Set the {@link RequestMatcher} to use for setting the + * {@link BaseOpenSamlAuthenticationRequestResolver#setRequestMatcher(RequestMatcher)} + * (RequestMatcher)} + * @param requestMatcher the {@link RequestMatcher} to identify authentication + * requests. + * @since 5.8 + */ + public void setRequestMatcher(RequestMatcher requestMatcher) { + Assert.notNull(requestMatcher, "requestMatcher cannot be null"); + this.delegate.setRequestMatcher(requestMatcher); + } + + /** + * Use this {@link Clock} for generating the issued {@link Instant} + * @param clock the {@link Clock} to use + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock must not be null"); + this.delegate.setClock(clock); + } + + /** + * Use this {@link Converter} to compute the RelayState + * @param relayStateResolver the {@link Converter} to use + * @since 5.8 + */ + public void setRelayStateResolver(Converter relayStateResolver) { + Assert.notNull(relayStateResolver, "relayStateResolver cannot be null"); + this.delegate.setRelayStateResolver(relayStateResolver); + } + + public static final class AuthnRequestContext { + + private final HttpServletRequest request; + + private final RelyingPartyRegistration registration; + + private final AuthnRequest authnRequest; + + public AuthnRequestContext(HttpServletRequest request, RelyingPartyRegistration registration, + AuthnRequest authnRequest) { + this.request = request; + this.registration = registration; + this.authnRequest = authnRequest; + } + + public HttpServletRequest getRequest() { + return this.request; + } + + public RelyingPartyRegistration getRelyingPartyRegistration() { + return this.registration; + } + + public AuthnRequest getAuthnRequest() { + return this.authnRequest; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5Template.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5Template.java new file mode 100644 index 0000000000..db40f084ee --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.shared.resolver.CriteriaSet; +import net.shibboleth.shared.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml5Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml5Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(Element element) { + return new OpenSaml5SerializationConfigurer(element); + } + + @Override + public OpenSaml5SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml5SignatureConfigurer(credentials); + } + + @Override + public OpenSaml5VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml5VerificationConfigurer(credentials); + } + + @Override + public OpenSaml5DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml5DecryptionConfigurer(credentials); + } + + OpenSaml5Template() { + + } + + static final class OpenSaml5SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml5SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml5SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml5SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml5SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml5SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml5VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml5VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml5DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml5DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestResolver.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestResolver.java new file mode 100644 index 0000000000..d5e53a1d7e --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestResolver.java @@ -0,0 +1,142 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import java.time.Clock; +import java.time.Instant; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletRequest; +import org.opensaml.saml.saml2.core.LogoutRequest; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.util.Assert; + +/** + * A {@link Saml2LogoutRequestResolver} for resolving SAML 2.0 Logout Requests with + * OpenSAML 4 + * + * @author Josh Cummings + * @author Gerhard Haege + * @since 5.6 + */ +public final class OpenSaml5LogoutRequestResolver implements Saml2LogoutRequestResolver { + + private final BaseOpenSamlLogoutRequestResolver delegate; + + public OpenSaml5LogoutRequestResolver(RelyingPartyRegistrationRepository registrations) { + this((request, id) -> { + if (id == null) { + return null; + } + return registrations.findByRegistrationId(id); + }); + } + + /** + * Construct a {@link OpenSaml5LogoutRequestResolver} + */ + public OpenSaml5LogoutRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + this.delegate = new BaseOpenSamlLogoutRequestResolver(relyingPartyRegistrationResolver, + new OpenSaml5Template()); + } + + /** + * {@inheritDoc} + */ + @Override + public Saml2LogoutRequest resolve(HttpServletRequest request, Authentication authentication) { + return this.delegate.resolve(request, authentication); + } + + /** + * Set a {@link Consumer} for modifying the OpenSAML {@link LogoutRequest} + * @param parametersConsumer a consumer that accepts an + * {@link LogoutRequestParameters} + */ + public void setParametersConsumer(Consumer parametersConsumer) { + Assert.notNull(parametersConsumer, "parametersConsumer cannot be null"); + this.delegate + .setParametersConsumer((parameters) -> parametersConsumer.accept(new LogoutRequestParameters(parameters))); + } + + /** + * Use this {@link Clock} for determining the issued {@link Instant} + * @param clock the {@link Clock} to use + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock must not be null"); + this.delegate.setClock(clock); + } + + /** + * Use this {@link Converter} to compute the RelayState + * @param relayStateResolver the {@link Converter} to use + * @since 6.1 + */ + public void setRelayStateResolver(Converter relayStateResolver) { + Assert.notNull(relayStateResolver, "relayStateResolver cannot be null"); + this.delegate.setRelayStateResolver(relayStateResolver); + } + + public static final class LogoutRequestParameters { + + private final HttpServletRequest request; + + private final RelyingPartyRegistration registration; + + private final Authentication authentication; + + private final LogoutRequest logoutRequest; + + public LogoutRequestParameters(HttpServletRequest request, RelyingPartyRegistration registration, + Authentication authentication, LogoutRequest logoutRequest) { + this.request = request; + this.registration = registration; + this.authentication = authentication; + this.logoutRequest = logoutRequest; + } + + LogoutRequestParameters(BaseOpenSamlLogoutRequestResolver.LogoutRequestParameters parameters) { + this(parameters.getRequest(), parameters.getRelyingPartyRegistration(), parameters.getAuthentication(), + parameters.getLogoutRequest()); + } + + public HttpServletRequest getRequest() { + return this.request; + } + + public RelyingPartyRegistration getRelyingPartyRegistration() { + return this.registration; + } + + public Authentication getAuthentication() { + return this.authentication; + } + + public LogoutRequest getLogoutRequest() { + return this.logoutRequest; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolver.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolver.java new file mode 100644 index 0000000000..0d3086dd92 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolver.java @@ -0,0 +1,100 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.core.OpenSamlInitializationService; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; +import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; + +/** + * An OpenSAML-based implementation of + * {@link Saml2LogoutRequestValidatorParametersResolver} + */ +public final class OpenSaml5LogoutRequestValidatorParametersResolver + implements Saml2LogoutRequestValidatorParametersResolver { + + static { + OpenSamlInitializationService.initialize(); + } + + private final BaseOpenSamlLogoutRequestValidatorParametersResolver delegate; + + /** + * Constructs a {@link OpenSaml5LogoutRequestValidatorParametersResolver} + */ + public OpenSaml5LogoutRequestValidatorParametersResolver(RelyingPartyRegistrationRepository registrations) { + Assert.notNull(registrations, "relyingPartyRegistrationRepository cannot be null"); + this.delegate = new BaseOpenSamlLogoutRequestValidatorParametersResolver(new OpenSaml5Template(), + registrations); + } + + /** + * Construct the parameters necessary for validating an asserting party's + * {@code } based on the given {@link HttpServletRequest} + * + *

+ * Uses the configured {@link RequestMatcher} to identify the processing request, + * including looking for any indicated {@code registrationId}. + * + *

+ * If a {@code registrationId} is found in the request, it will attempt to use that, + * erroring if no {@link RelyingPartyRegistration} is found. + * + *

+ * If no {@code registrationId} is found in the request, it will look for a currently + * logged-in user and use the associated {@code registrationId}. + * + *

+ * In the event that neither the URL nor any logged in user could determine a + * {@code registrationId}, this code then will try and derive a + * {@link RelyingPartyRegistration} given the {@code }'s + * {@code Issuer} value. + * @param request the HTTP request + * @return a {@link Saml2LogoutRequestValidatorParameters} instance, or {@code null} + * if one could not be resolved + * @throws Saml2AuthenticationException if the {@link RequestMatcher} specifies a + * non-existent {@code registrationId} + */ + @Override + public Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request, Authentication authentication) { + return this.delegate.resolve(request, authentication); + } + + /** + * The request matcher to use to identify a request to process a + * {@code }. By default, checks for {@code /logout/saml2/slo} and + * {@code /logout/saml2/slo/{registrationId}}. + * + *

+ * Generally speaking, the URL does not need to have a {@code registrationId} in it + * since either it can be looked up from the active logged in user or it can be + * derived through the {@code Issuer} in the {@code }. + * @param requestMatcher the {@link RequestMatcher} to use + */ + public void setRequestMatcher(RequestMatcher requestMatcher) { + Assert.notNull(requestMatcher, "requestMatcher cannot be null"); + this.delegate.setRequestMatcher(requestMatcher); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutResponseResolver.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutResponseResolver.java new file mode 100644 index 0000000000..2478da4917 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutResponseResolver.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import java.time.Clock; +import java.time.Instant; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletRequest; +import org.opensaml.saml.saml2.core.LogoutRequest; + +import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutResponse; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.util.Assert; + +/** + * A {@link Saml2LogoutResponseResolver} for resolving SAML 2.0 Logout Responses with + * OpenSAML 4 + * + * @author Josh Cummings + * @since 5.6 + */ +public final class OpenSaml5LogoutResponseResolver implements Saml2LogoutResponseResolver { + + private final BaseOpenSamlLogoutResponseResolver delegate; + + public OpenSaml5LogoutResponseResolver(RelyingPartyRegistrationRepository registrations) { + this.delegate = new BaseOpenSamlLogoutResponseResolver(registrations, (request, id) -> { + if (id == null) { + return null; + } + return registrations.findByRegistrationId(id); + }, new OpenSaml5Template()); + } + + /** + * Construct a {@link OpenSaml5LogoutResponseResolver} + */ + public OpenSaml5LogoutResponseResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + this.delegate = new BaseOpenSamlLogoutResponseResolver(null, relyingPartyRegistrationResolver, + new OpenSaml5Template()); + } + + /** + * {@inheritDoc} + */ + @Override + public Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication) { + return this.delegate.resolve(request, authentication); + } + + /** + * Set a {@link Consumer} for modifying the OpenSAML {@link LogoutRequest} + * @param parametersConsumer a consumer that accepts an + * {@link OpenSaml5LogoutRequestResolver.LogoutRequestParameters} + */ + public void setParametersConsumer(Consumer parametersConsumer) { + Assert.notNull(parametersConsumer, "parametersConsumer cannot be null"); + this.delegate + .setParametersConsumer((parameters) -> parametersConsumer.accept(new LogoutResponseParameters(parameters))); + } + + /** + * Use this {@link Clock} for determining the issued {@link Instant} + * @param clock the {@link Clock} to use + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock must not be null"); + this.delegate.setClock(clock); + } + + public static final class LogoutResponseParameters { + + private final HttpServletRequest request; + + private final RelyingPartyRegistration registration; + + private final Authentication authentication; + + private final LogoutRequest logoutRequest; + + public LogoutResponseParameters(HttpServletRequest request, RelyingPartyRegistration registration, + Authentication authentication, LogoutRequest logoutRequest) { + this.request = request; + this.registration = registration; + this.authentication = authentication; + this.logoutRequest = logoutRequest; + } + + LogoutResponseParameters(BaseOpenSamlLogoutResponseResolver.LogoutResponseParameters parameters) { + this(parameters.getRequest(), parameters.getRelyingPartyRegistration(), parameters.getAuthentication(), + parameters.getLogoutRequest()); + } + + public HttpServletRequest getRequest() { + return this.request; + } + + public RelyingPartyRegistration getRelyingPartyRegistration() { + return this.registration; + } + + public Authentication getAuthentication() { + return this.authentication; + } + + public LogoutRequest getLogoutRequest() { + return this.logoutRequest; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5Template.java b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5Template.java new file mode 100644 index 0000000000..2d095694f5 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.shared.resolver.CriteriaSet; +import net.shibboleth.shared.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml5Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml5Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml5SerializationConfigurer serialize(Element element) { + return new OpenSaml5SerializationConfigurer(element); + } + + @Override + public OpenSaml5SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml5SignatureConfigurer(credentials); + } + + @Override + public OpenSaml5VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml5VerificationConfigurer(credentials); + } + + @Override + public OpenSaml5DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml5DecryptionConfigurer(credentials); + } + + OpenSaml5Template() { + + } + + static final class OpenSaml5SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml5SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml5SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml5SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml5SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml5SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml5VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml5VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml5DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml5DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5AuthenticationProviderTests.java new file mode 100644 index 0000000000..22ed0e89b6 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml5AuthenticationProviderTests.java @@ -0,0 +1,945 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import javax.xml.namespace.QName; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.schema.XSDateTime; +import org.opensaml.core.xml.schema.impl.XSDateTimeBuilder; +import org.opensaml.saml.common.SignableSAMLObject; +import org.opensaml.saml.common.assertion.ValidationContext; +import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.AttributeValue; +import org.opensaml.saml.saml2.core.Conditions; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.EncryptedID; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.OneTimeUse; +import org.opensaml.saml.saml2.core.ProxyRestriction; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.Status; +import org.opensaml.saml.saml2.core.StatusCode; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.core.SubjectConfirmationData; +import org.opensaml.saml.saml2.core.impl.AttributeBuilder; +import org.opensaml.saml.saml2.core.impl.EncryptedAssertionBuilder; +import org.opensaml.saml.saml2.core.impl.EncryptedIDBuilder; +import org.opensaml.saml.saml2.core.impl.NameIDBuilder; +import org.opensaml.saml.saml2.core.impl.ProxyRestrictionBuilder; +import org.opensaml.saml.saml2.core.impl.StatusBuilder; +import org.opensaml.saml.saml2.core.impl.StatusCodeBuilder; +import org.opensaml.xmlsec.encryption.impl.EncryptedDataBuilder; +import org.opensaml.xmlsec.signature.support.SignatureConstants; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.core.Authentication; +import org.springframework.security.jackson2.SecurityJackson2Modules; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.authentication.OpenSaml5AuthenticationProvider.ResponseToken; +import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSaml5Objects.CustomOpenSamlObject; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.util.StringUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link OpenSaml5AuthenticationProvider} + * + * @author Filip Hanik + * @author Josh Cummings + */ +public class OpenSaml5AuthenticationProviderTests { + + private static String DESTINATION = "https://localhost/login/saml2/sso/idp-alias"; + + private static String RELYING_PARTY_ENTITY_ID = "https://localhost/saml2/service-provider-metadata/idp-alias"; + + private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp"; + + private final OpenSamlOperations saml = new OpenSaml5Template(); + + private OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + + private Saml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("name", + Collections.emptyMap()); + + private Saml2Authentication authentication = new Saml2Authentication(this.principal, "response", + Collections.emptyList()); + + @Test + public void supportsWhenSaml2AuthenticationTokenThenReturnTrue() { + assertThat(this.provider.supports(Saml2AuthenticationToken.class)) + .withFailMessage(OpenSaml5AuthenticationProvider.class + "should support " + Saml2AuthenticationToken.class) + .isTrue(); + } + + @Test + public void supportsWhenNotSaml2AuthenticationTokenThenReturnFalse() { + assertThat(!this.provider.supports(Authentication.class)) + .withFailMessage(OpenSaml5AuthenticationProvider.class + "should not support " + Authentication.class) + .isTrue(); + } + + @Test + public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() { + Assertion assertion = (Assertion) XMLObjectProviderRegistrySupport.getBuilderFactory() + .getBuilder(Assertion.DEFAULT_ELEMENT_NAME) + .buildObject(Assertion.DEFAULT_ELEMENT_NAME); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider + .authenticate(new Saml2AuthenticationToken(verifying(registration()).build(), serialize(assertion)))) + .satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); + } + + @Test + public void authenticateWhenXmlErrorThenThrowAuthenticationException() { + Saml2AuthenticationToken token = new Saml2AuthenticationToken(verifying(registration()).build(), "invalid xml"); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); + } + + @Test + public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() { + Response response = response(DESTINATION + "invalid", ASSERTING_PARTY_ENTITY_ID); + response.getAssertions().add(assertion()); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.INVALID_DESTINATION)); + } + + @Test + public void authenticateWhenNoAssertionsPresentThenThrowAuthenticationException() { + Saml2AuthenticationToken token = token(); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response.")); + } + + @Test + public void authenticateWhenInvalidSignatureOnAssertionThenThrowAuthenticationException() { + Response response = response(); + response.getAssertions().add(assertion()); + Saml2AuthenticationToken token = token(response, verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE)); + } + + @Test + public void authenticateWhenOpenSAMLValidationErrorThenThrowAuthenticationException() { + Response response = response(); + Assertion assertion = assertion(); + assertion.getSubject() + .getSubjectConfirmations() + .get(0) + .getSubjectConfirmationData() + .setNotOnOrAfter(Instant.now().minus(Duration.ofDays(3))); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.INVALID_ASSERTION)); + } + + @Test + public void authenticateWhenMissingSubjectThenThrowAuthenticationException() { + Response response = response(); + Assertion assertion = assertion(); + assertion.setSubject(null); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.SUBJECT_NOT_FOUND)); + } + + @Test + public void authenticateWhenUsernameMissingThenThrowAuthenticationException() { + Response response = response(); + Assertion assertion = assertion(); + assertion.getSubject().getNameID().setValue(null); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.SUBJECT_NOT_FOUND)); + } + + @Test + public void authenticateWhenAssertionContainsValidationAddressThenItSucceeds() { + Response response = response(); + Assertion assertion = assertion(); + assertion.getSubject() + .getSubjectConfirmations() + .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10")); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + this.provider.authenticate(token); + } + + @Test + public void evaluateInResponseToSucceedsWhenInResponseToInResponseAndAssertionsMatchRequestID() { + Response response = response(); + response.setInResponseTo("SAML2"); + response.getAssertions().add(signed(assertion("SAML2"))); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + this.provider.authenticate(token); + } + + @Test + public void evaluateInResponseToSucceedsWhenInResponseToInAssertionOnlyMatchRequestID() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + this.provider.authenticate(token); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInAssertionMismatchWithRequestID() { + Response response = response(); + response.setInResponseTo("SAML2"); + response.getAssertions().add(signed(assertion("SAML2"))); + response.getAssertions().add(signed(assertion("BAD"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .withStackTraceContaining("invalid_assertion"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndMismatchWithRequestID() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + response.getAssertions().add(signed(assertion("BAD"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .withStackTraceContaining("invalid_assertion"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseInToResponseMismatchWithRequestID() { + Response response = response(); + response.setInResponseTo("BAD"); + response.getAssertions().add(signed(assertion("SAML2"))); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .withStackTraceContaining("invalid_in_response_to"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest() { + Response response = response(); + response.setInResponseTo("BAD"); + Saml2AuthenticationToken token = token(response, verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .withStackTraceContaining("invalid_in_response_to"); + } + + @Test + public void evaluateInResponseToSucceedsWhenNoInResponseToInResponseOrAssertions() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + this.provider.authenticate(token); + } + + @Test + public void authenticateWhenAssertionContainsAttributesThenItSucceeds() { + Response response = response(); + Assertion assertion = assertion(); + List attributes = attributeStatements(); + assertion.getAttributeStatements().addAll(attributes); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + Authentication authentication = this.provider.authenticate(token); + Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); + Map expected = new LinkedHashMap<>(); + expected.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com")); + expected.put("name", Collections.singletonList("John Doe")); + expected.put("age", Collections.singletonList(21)); + expected.put("website", Collections.singletonList("https://johndoe.com/")); + expected.put("registered", Collections.singletonList(true)); + Instant registeredDate = Instant.parse("1970-01-01T00:00:00Z"); + expected.put("registeredDate", Collections.singletonList(registeredDate)); + expected.put("role", Arrays.asList("RoleOne", "RoleTwo")); // gh-11042 + assertThat((String) principal.getFirstAttribute("name")).isEqualTo("John Doe"); + assertThat(principal.getAttributes()).isEqualTo(expected); + assertThat(principal.getSessionIndexes()).contains("session-index"); + } + + // gh-11785 + @Test + public void deserializeWhenAssertionContainsAttributesThenWorks() throws Exception { + ObjectMapper mapper = new ObjectMapper(); + ClassLoader loader = getClass().getClassLoader(); + mapper.registerModules(SecurityJackson2Modules.getModules(loader)); + Response response = response(); + Assertion assertion = assertion(); + List attributes = TestOpenSamlObjects.attributeStatements(); + assertion.getAttributeStatements().addAll(attributes); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + Authentication authentication = this.provider.authenticate(token); + String result = mapper.writeValueAsString(authentication); + mapper.readValue(result, Authentication.class); + } + + @Test + public void authenticateWhenAssertionContainsCustomAttributesThenItSucceeds() { + Response response = response(); + Assertion assertion = assertion(); + AttributeStatement attribute = TestOpenSamlObjects.customAttributeStatement("Address", + TestCustomOpenSaml5Objects.instance()); + assertion.getAttributeStatements().add(attribute); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + Authentication authentication = this.provider.authenticate(token); + Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); + CustomOpenSamlObject address = (CustomOpenSamlObject) principal.getAttribute("Address").get(0); + assertThat(address.getStreet()).isEqualTo("Test Street"); + assertThat(address.getStreetNumber()).isEqualTo("1"); + assertThat(address.getZIP()).isEqualTo("11111"); + assertThat(address.getCity()).isEqualTo("Test City"); + } + + @Test + public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() { + Response response = response(); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion(), + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); + response.getEncryptedAssertions().add(encryptedAssertion); + Saml2AuthenticationToken token = token(response, decrypting(verifying(registration()))); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE, "Did not decrypt response")); + } + + @Test + public void authenticateWhenEncryptedAssertionWithSignatureThenItSucceeds() { + Response response = response(); + Assertion assertion = TestOpenSamlObjects.signed(assertion(), + TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion, + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); + response.getEncryptedAssertions().add(encryptedAssertion); + Saml2AuthenticationToken token = token(signed(response), decrypting(verifying(registration()))); + this.provider.authenticate(token); + } + + @Test + public void authenticateWhenEncryptedAssertionWithResponseSignatureThenItSucceeds() { + Response response = response(); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion(), + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); + response.getEncryptedAssertions().add(encryptedAssertion); + Saml2AuthenticationToken token = token(signed(response), decrypting(verifying(registration()))); + this.provider.authenticate(token); + } + + @Test + public void authenticateWhenEncryptedNameIdWithSignatureThenItSucceeds() { + Response response = response(); + Assertion assertion = assertion(); + NameID nameId = assertion.getSubject().getNameID(); + EncryptedID encryptedID = TestOpenSamlObjects.encrypted(nameId, + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); + assertion.getSubject().setNameID(null); + assertion.getSubject().setEncryptedID(encryptedID); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, decrypting(verifying(registration()))); + this.provider.authenticate(token); + } + + @Test + public void authenticateWhenEncryptedAttributeThenDecrypts() { + Response response = response(); + Assertion assertion = assertion(); + EncryptedAttribute attribute = TestOpenSamlObjects.encrypted("name", "value", + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); + AttributeStatement statement = build(AttributeStatement.DEFAULT_ELEMENT_NAME); + statement.getEncryptedAttributes().add(attribute); + assertion.getAttributeStatements().add(statement); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(signed(response), decrypting(verifying(registration()))); + Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token); + Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); + assertThat(principal.getAttribute("name")).containsExactly("value"); + } + + @Test + public void authenticateWhenDecryptionKeysAreMissingThenThrowAuthenticationException() { + Response response = response(); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion(), + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); + response.getEncryptedAssertions().add(encryptedAssertion); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData")); + } + + @Test + public void authenticateWhenDecryptionKeysAreWrongThenThrowAuthenticationException() { + Response response = response(); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion(), + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); + response.getEncryptedAssertions().add(encryptedAssertion); + Saml2AuthenticationToken token = token(signed(response), registration() + .decryptionX509Credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartyPrivateCredential()))); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData")); + } + + @Test + public void authenticateWhenAuthenticationHasDetailsThenSucceeds() { + Response response = response(); + Assertion assertion = assertion(); + assertion.getSubject() + .getSubjectConfirmations() + .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10")); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + token.setDetails("some-details"); + Authentication authentication = this.provider.authenticate(token); + assertThat(authentication.getDetails()).isEqualTo("some-details"); + } + + @Test + public void writeObjectWhenTypeIsSaml2AuthenticationThenNoException() throws IOException { + Response response = response(); + Assertion assertion = TestOpenSamlObjects.signed(assertion(), + TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion, + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); + response.getEncryptedAssertions().add(encryptedAssertion); + Saml2AuthenticationToken token = token(signed(response), decrypting(verifying(registration()))); + Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token); + // the following code will throw an exception if authentication isn't serializable + ByteArrayOutputStream byteStream = new ByteArrayOutputStream(1024); + ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteStream); + objectOutputStream.writeObject(authentication); + objectOutputStream.flush(); + } + + @Test + public void createDefaultAssertionValidatorWhenAssertionThenValidates() { + Response response = TestOpenSamlObjects.signedResponseWithOneAssertion(); + Assertion assertion = response.getAssertions().get(0); + OpenSaml5AuthenticationProvider.AssertionToken assertionToken = new OpenSaml5AuthenticationProvider.AssertionToken( + assertion, token()); + assertThat( + OpenSaml5AuthenticationProvider.createDefaultAssertionValidator().convert(assertionToken).hasErrors()) + .isFalse(); + } + + @Test + public void authenticateWhenDelegatingToDefaultAssertionValidatorThenUses() { + OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + // @formatter:off + provider.setAssertionValidator((assertionToken) -> OpenSaml5AuthenticationProvider + .createDefaultAssertionValidator((token) -> new ValidationContext()) + .convert(assertionToken) + .concat(new Saml2Error("wrong error", "wrong error")) + ); + // @formatter:on + Response response = response(); + Assertion assertion = assertion(); + OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME); + assertion.getConditions().getConditions().add(oneTimeUse); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); + // @formatter:off + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> provider.authenticate(token)).isInstanceOf(Saml2AuthenticationException.class) + .satisfies((error) -> assertThat(error.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_ASSERTION)); + // @formatter:on + } + + // gh-11675 + @Test + public void authenticateWhenUsingCustomAssertionValidatorThenUses() { + OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + Consumer> validationParameters = mock(Consumer.class); + // @formatter:off + provider.setAssertionValidator(OpenSaml5AuthenticationProvider + .createDefaultAssertionValidatorWithParameters(validationParameters)); + // @formatter:on + Response response = response(); + Assertion assertion = assertion(); + OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME); + assertion.getConditions().getConditions().add(oneTimeUse); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); + provider.authenticate(token); + verify(validationParameters).accept(any()); + } + + @Test + public void authenticateWhenCustomAssertionValidatorThenUses() { + Converter validator = mock( + Converter.class); + OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + // @formatter:off + provider.setAssertionValidator((assertionToken) -> OpenSaml5AuthenticationProvider.createDefaultAssertionValidator() + .convert(assertionToken) + .concat(validator.convert(assertionToken)) + ); + // @formatter:on + Response response = response(); + Assertion assertion = assertion(); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); + given(validator.convert(any(OpenSaml5AuthenticationProvider.AssertionToken.class))) + .willReturn(Saml2ResponseValidatorResult.success()); + provider.authenticate(token); + verify(validator).convert(any(OpenSaml5AuthenticationProvider.AssertionToken.class)); + } + + @Test + public void authenticateWhenDefaultConditionValidatorNotUsedThenSignatureStillChecked() { + OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + provider.setAssertionValidator((assertionToken) -> Saml2ResponseValidatorResult.success()); + Response response = response(); + Assertion assertion = assertion(); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.relyingPartyDecryptingCredential(), + RELYING_PARTY_ENTITY_ID); // broken + // signature + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); + // @formatter:off + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> provider.authenticate(token)) + .satisfies((error) -> assertThat(error.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_SIGNATURE)); + // @formatter:on + } + + @Test + public void authenticateWhenValidationContextCustomizedThenUsers() { + Map parameters = new HashMap<>(); + parameters.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton("blah")); + ValidationContext context = mock(ValidationContext.class); + given(context.getStaticParameters()).willReturn(parameters); + OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + provider.setAssertionValidator( + OpenSaml5AuthenticationProvider.createDefaultAssertionValidator((assertionToken) -> context)); + Response response = response(); + Assertion assertion = assertion(); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + // @formatter:off + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> provider.authenticate(token)).isInstanceOf(Saml2AuthenticationException.class) + .satisfies((error) -> assertThat(error).hasMessageContaining("Invalid assertion")); + // @formatter:on + verify(context, atLeastOnce()).getStaticParameters(); + } + + @Test + public void authenticateWithSHA1SignatureThenItSucceeds() throws Exception { + Response response = response(); + Assertion assertion = TestOpenSamlObjects.signed(assertion(), + TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID, + SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(response, verifying(registration())); + this.provider.authenticate(token); + } + + @Test + public void setAssertionValidatorWhenNullThenIllegalArgument() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.provider.setAssertionValidator(null)); + // @formatter:on + } + + @Test + public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() { + Response response = TestOpenSamlObjects.signedResponseWithOneAssertion(); + Saml2AuthenticationToken token = token(response, verifying(registration())); + ResponseToken responseToken = new ResponseToken(response, token); + Saml2Authentication authentication = OpenSaml5AuthenticationProvider + .createDefaultResponseAuthenticationConverter() + .convert(responseToken); + assertThat(authentication.getName()).isEqualTo("test@saml.user"); + } + + @Test + public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() { + Converter authenticationConverter = mock(Converter.class); + OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + provider.setResponseAuthenticationConverter(authenticationConverter); + Response response = TestOpenSamlObjects.signedResponseWithOneAssertion(); + Saml2AuthenticationToken token = token(response, verifying(registration())); + provider.authenticate(token); + verify(authenticationConverter).convert(any()); + } + + @Test + public void setResponseAuthenticationConverterWhenNullThenIllegalArgument() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.provider.setResponseAuthenticationConverter(null)); + // @formatter:on + } + + @Test + public void setResponseElementsDecrypterWhenNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setResponseElementsDecrypter(null)); + } + + @Test + public void setAssertionElementsDecrypterWhenNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionElementsDecrypter(null)); + } + + @Test + public void authenticateWhenCustomResponseElementsDecrypterThenDecryptsResponse() { + Response response = response(); + Assertion assertion = assertion(); + response.getEncryptedAssertions().add(new EncryptedAssertionBuilder().buildObject()); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, verifying(registration())); + this.provider + .setResponseElementsDecrypter((tuple) -> tuple.getResponse().getAssertions().add(signed(assertion))); + Authentication authentication = this.provider.authenticate(token); + assertThat(authentication.getName()).isEqualTo("test@saml.user"); + } + + @Test + public void authenticateWhenCustomAssertionElementsDecrypterThenDecryptsAssertion() { + Response response = response(); + Assertion assertion = assertion(); + EncryptedID id = new EncryptedIDBuilder().buildObject(); + id.setEncryptedData(new EncryptedDataBuilder().buildObject()); + assertion.getSubject().setEncryptedID(id); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + this.provider.setAssertionElementsDecrypter((tuple) -> { + NameID name = new NameIDBuilder().buildObject(); + name.setValue("decrypted name"); + tuple.getAssertion().getSubject().setNameID(name); + }); + Authentication authentication = this.provider.authenticate(token); + assertThat(authentication.getName()).isEqualTo("decrypted name"); + } + + @Test + public void authenticateWhenResponseStatusIsNotSuccessThenFails() { + Response response = TestOpenSamlObjects + .signedResponseWithOneAssertion((r) -> r.setStatus(TestOpenSamlObjects.status(StatusCode.AUTHN_FAILED))); + Saml2AuthenticationToken token = token(response, verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.INVALID_RESPONSE, "Invalid status")); + } + + @Test + public void authenticateWhenResponseStatusIsSuccessThenSucceeds() { + Response response = TestOpenSamlObjects + .signedResponseWithOneAssertion((r) -> r.setStatus(TestOpenSamlObjects.successStatus())); + Saml2AuthenticationToken token = token(response, verifying(registration())); + Authentication authentication = this.provider.authenticate(token); + assertThat(authentication.getName()).isEqualTo("test@saml.user"); + } + + @Test + public void setResponseValidatorWhenNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setResponseValidator(null)); + } + + @Test + public void authenticateWhenCustomResponseValidatorThenUses() { + Converter validator = mock( + Converter.class); + OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + // @formatter:off + provider.setResponseValidator((responseToken) -> OpenSaml5AuthenticationProvider.createDefaultResponseValidator() + .convert(responseToken) + .concat(validator.convert(responseToken)) + ); + // @formatter:on + Response response = response(); + Assertion assertion = assertion(); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); + given(validator.convert(any(OpenSaml5AuthenticationProvider.ResponseToken.class))) + .willReturn(Saml2ResponseValidatorResult.success()); + provider.authenticate(token); + verify(validator).convert(any(OpenSaml5AuthenticationProvider.ResponseToken.class)); + } + + @Test + public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() { + Saml2AuthenticationToken token = TestSaml2AuthenticationTokens.token(); + + Status parentStatus = new StatusBuilder().buildObject(); + StatusCode parentStatusCode = new StatusCodeBuilder().buildObject(); + parentStatusCode.setValue(StatusCode.AUTHN_FAILED); + StatusCode childStatusCode = new StatusCodeBuilder().buildObject(); + childStatusCode.setValue(StatusCode.NO_PASSIVE); + parentStatusCode.setStatusCode(childStatusCode); + parentStatus.setStatusCode(parentStatusCode); + + Response response = TestOpenSamlObjects.response(); + response.setStatus(parentStatus); + response.setIssuer(TestOpenSamlObjects.issuer("mockedIssuer")); + + Converter validator = OpenSaml5AuthenticationProvider + .createDefaultResponseValidator(); + Saml2ResponseValidatorResult result = validator.convert(new ResponseToken(response, token)); + + String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", + parentStatusCode.getValue()); + assertThat( + result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage))) + .isTrue(); + assertThat(result.getErrors() + .stream() + .noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue()))).isTrue(); + } + + @Test + public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() { + Saml2AuthenticationToken token = TestSaml2AuthenticationTokens.token(); + Status parentStatus = new StatusBuilder().buildObject(); + StatusCode parentStatusCode = new StatusCodeBuilder().buildObject(); + parentStatusCode.setValue(StatusCode.REQUESTER); + StatusCode childStatusCode = new StatusCodeBuilder().buildObject(); + childStatusCode.setValue(StatusCode.NO_PASSIVE); + parentStatusCode.setStatusCode(childStatusCode); + parentStatus.setStatusCode(parentStatusCode); + + Response response = TestOpenSamlObjects.response(); + response.setStatus(parentStatus); + response.setIssuer(TestOpenSamlObjects.issuer("mockedIssuer")); + + Converter validator = OpenSaml5AuthenticationProvider + .createDefaultResponseValidator(); + Saml2ResponseValidatorResult result = validator.convert(new ResponseToken(response, token)); + + String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response", + parentStatusCode.getValue()); + String expectedChildErrorMessage = String.format("Invalid status [%s] for SAML response", + childStatusCode.getValue()); + assertThat(result.getErrors() + .stream() + .anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage))).isTrue(); + assertThat(result.getErrors() + .stream() + .anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage))).isTrue(); + } + + @Test + public void authenticateWhenAssertionIssuerNotValidThenFailsWithInvalidIssuer() { + OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + Response response = response(); + Assertion assertion = assertion(); + assertion.setIssuer(TestOpenSamlObjects.issuer("https://invalid.idp.test/saml2/idp")); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> provider.authenticate(token)) + .withMessageContaining("did not match any valid issuers"); + } + + // gh-14931 + @Test + public void authenticateWhenAssertionHasProxyRestrictionThenParses() { + OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + Response response = response(); + Assertion assertion = assertion(); + ProxyRestriction condition = new ProxyRestrictionBuilder().buildObject(); + assertion.getConditions().getConditions().add(condition); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); + provider.authenticate(token); + } + + // gh-15022 + @Test + public void authenticateWhenClockSkewThenVerifiesSignature() { + OpenSaml5AuthenticationProvider provider = new OpenSaml5AuthenticationProvider(); + provider.setAssertionValidator(OpenSaml5AuthenticationProvider.createDefaultAssertionValidatorWithParameters( + (params) -> params.put(SAML2AssertionValidationParameters.CLOCK_SKEW, Duration.ofMinutes(10)))); + Response response = response(); + Assertion assertion = assertion(); + assertion.setIssueInstant(Instant.now().plus(Duration.ofMinutes(9))); + response.getAssertions().add(assertion); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); + provider.authenticate(token); + } + + private T build(QName qName) { + return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName); + } + + private String serialize(XMLObject object) { + return this.saml.serialize(object).serialize(); + } + + private Consumer errorOf(String errorCode) { + return errorOf(errorCode, null); + } + + private Consumer errorOf(String errorCode, String description) { + return (ex) -> { + assertThat(ex.getSaml2Error().getErrorCode()).isEqualTo(errorCode); + if (StringUtils.hasText(description)) { + assertThat(ex.getSaml2Error().getDescription()).contains(description); + } + }; + } + + private Response response() { + Response response = TestOpenSamlObjects.response(); + response.setIssueInstant(Instant.now()); + return response; + } + + private Response response(String destination, String issuerEntityId) { + Response response = TestOpenSamlObjects.response(destination, issuerEntityId); + response.setIssueInstant(Instant.now()); + return response; + } + + private Assertion assertion(String inResponseTo) { + Assertion assertion = TestOpenSamlObjects.assertion(); + assertion.setIssueInstant(Instant.now()); + for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) { + SubjectConfirmationData data = confirmation.getSubjectConfirmationData(); + data.setNotBefore(Instant.now().minus(Duration.ofMillis(5 * 60 * 1000))); + data.setNotOnOrAfter(Instant.now().plus(Duration.ofMillis(5 * 60 * 1000))); + if (StringUtils.hasText(inResponseTo)) { + data.setInResponseTo(inResponseTo); + } + } + Conditions conditions = assertion.getConditions(); + conditions.setNotBefore(Instant.now().minus(Duration.ofMillis(5 * 60 * 1000))); + conditions.setNotOnOrAfter(Instant.now().plus(Duration.ofMillis(5 * 60 * 1000))); + return assertion; + } + + private Assertion assertion() { + return assertion(null); + } + + private T signed(T toSign) { + TestOpenSamlObjects.signed(toSign, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + return toSign; + } + + private List attributeStatements() { + List attributeStatements = TestOpenSamlObjects.attributeStatements(); + AttributeBuilder attributeBuilder = new AttributeBuilder(); + Attribute registeredDateAttr = attributeBuilder.buildObject(); + registeredDateAttr.setName("registeredDate"); + XSDateTime registeredDate = new XSDateTimeBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, + XSDateTime.TYPE_NAME); + registeredDate.setValue(Instant.parse("1970-01-01T00:00:00Z")); + registeredDateAttr.getAttributeValues().add(registeredDate); + attributeStatements.iterator().next().getAttributes().add(registeredDateAttr); + return attributeStatements; + } + + private Saml2AuthenticationToken token() { + Response response = response(); + RelyingPartyRegistration registration = verifying(registration()).build(); + return new Saml2AuthenticationToken(registration, serialize(response)); + } + + private Saml2AuthenticationToken token(Response response, RelyingPartyRegistration.Builder registration) { + return new Saml2AuthenticationToken(registration.build(), serialize(response)); + } + + private Saml2AuthenticationToken token(Response response, RelyingPartyRegistration.Builder registration, + AbstractSaml2AuthenticationRequest authenticationRequest) { + return new Saml2AuthenticationToken(registration.build(), serialize(response), authenticationRequest); + } + + private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId) { + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + given(mockAuthenticationRequest.getId()).willReturn(requestId); + return mockAuthenticationRequest; + } + + private RelyingPartyRegistration.Builder registration() { + return TestRelyingPartyRegistrations.noCredentials() + .entityId(RELYING_PARTY_ENTITY_ID) + .assertionConsumerServiceLocation(DESTINATION) + .assertingPartyDetails((party) -> party.entityId(ASSERTING_PARTY_ENTITY_ID)); + } + + private RelyingPartyRegistration.Builder verifying(RelyingPartyRegistration.Builder builder) { + return builder.assertingPartyDetails((party) -> party + .verificationX509Credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()))); + } + + private RelyingPartyRegistration.Builder decrypting(RelyingPartyRegistration.Builder builder) { + return builder + .decryptionX509Credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyDecryptingCredential())); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSaml5Objects.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSaml5Objects.java new file mode 100644 index 0000000000..cdb98e47c4 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSaml5Objects.java @@ -0,0 +1,213 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import java.util.Collections; +import java.util.List; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import javax.xml.namespace.QName; + +import net.shibboleth.shared.xml.ElementSupport; +import org.opensaml.core.xml.AbstractXMLObject; +import org.opensaml.core.xml.AbstractXMLObjectBuilder; +import org.opensaml.core.xml.ElementExtensibleXMLObject; +import org.opensaml.core.xml.Namespace; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.AbstractXMLObjectMarshaller; +import org.opensaml.core.xml.io.AbstractXMLObjectUnmarshaller; +import org.opensaml.core.xml.io.UnmarshallingException; +import org.opensaml.core.xml.schema.XSAny; +import org.opensaml.core.xml.schema.impl.XSAnyBuilder; +import org.opensaml.core.xml.util.IndexedXMLObjectChildrenList; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.saml2.core.AttributeValue; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.core.OpenSamlInitializationService; + +public final class TestCustomOpenSaml5Objects { + + static { + OpenSamlInitializationService.initialize(); + XMLObjectProviderRegistrySupport.getMarshallerFactory() + .registerMarshaller(CustomOpenSamlObject.TYPE_NAME, + new TestCustomOpenSaml5Objects.CustomSamlObjectMarshaller()); + XMLObjectProviderRegistrySupport.getUnmarshallerFactory() + .registerUnmarshaller(CustomOpenSamlObject.TYPE_NAME, + new TestCustomOpenSaml5Objects.CustomSamlObjectUnmarshaller()); + } + + public static CustomOpenSamlObject instance() { + CustomOpenSamlObject samlObject = new TestCustomOpenSaml5Objects.CustomSamlObjectBuilder() + .buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, CustomOpenSamlObject.TYPE_NAME); + XSAny street = new XSAnyBuilder().buildObject(CustomOpenSamlObject.CUSTOM_NS, "Street", + CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); + street.setTextContent("Test Street"); + samlObject.getUnknownXMLObjects().add(street); + XSAny streetNumber = new XSAnyBuilder().buildObject(CustomOpenSamlObject.CUSTOM_NS, "Number", + CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); + streetNumber.setTextContent("1"); + samlObject.getUnknownXMLObjects().add(streetNumber); + XSAny zip = new XSAnyBuilder().buildObject(CustomOpenSamlObject.CUSTOM_NS, "ZIP", + CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); + zip.setTextContent("11111"); + samlObject.getUnknownXMLObjects().add(zip); + XSAny city = new XSAnyBuilder().buildObject(CustomOpenSamlObject.CUSTOM_NS, "City", + CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); + city.setTextContent("Test City"); + samlObject.getUnknownXMLObjects().add(city); + return samlObject; + } + + private TestCustomOpenSaml5Objects() { + + } + + public interface CustomOpenSamlObject extends ElementExtensibleXMLObject { + + String TYPE_LOCAL_NAME = "CustomType"; + + String TYPE_CUSTOM_PREFIX = "custom"; + + String CUSTOM_NS = "https://custom.com/schema/custom"; + + /** QName of the CustomType type. */ + QName TYPE_NAME = new QName(CUSTOM_NS, TYPE_LOCAL_NAME, TYPE_CUSTOM_PREFIX); + + String getStreet(); + + String getStreetNumber(); + + String getZIP(); + + String getCity(); + + } + + public static class CustomOpenSamlObjectImpl extends AbstractXMLObject implements CustomOpenSamlObject { + + @Nonnull + private IndexedXMLObjectChildrenList unknownXMLObjects; + + /** + * Constructor. + * @param namespaceURI the namespace the element is in + * @param elementLocalName the local name of the XML element this Object + * represents + * @param namespacePrefix the prefix for the given namespace + */ + protected CustomOpenSamlObjectImpl(@Nullable String namespaceURI, @Nonnull String elementLocalName, + @Nullable String namespacePrefix) { + super(namespaceURI, elementLocalName, namespacePrefix); + super.getNamespaceManager().registerNamespaceDeclaration(new Namespace(CUSTOM_NS, TYPE_CUSTOM_PREFIX)); + this.unknownXMLObjects = new IndexedXMLObjectChildrenList<>(this); + } + + @Nonnull + @Override + public List getUnknownXMLObjects() { + return this.unknownXMLObjects; + } + + @Nonnull + @Override + public List getUnknownXMLObjects(@Nonnull QName typeOrName) { + return (List) this.unknownXMLObjects.subList(typeOrName); + } + + @Nullable + @Override + public List getOrderedChildren() { + return Collections.unmodifiableList(this.unknownXMLObjects); + } + + @Override + public String getStreet() { + return ((XSAny) getOrderedChildren().get(0)).getTextContent(); + } + + @Override + public String getStreetNumber() { + return ((XSAny) getOrderedChildren().get(1)).getTextContent(); + } + + @Override + public String getZIP() { + return ((XSAny) getOrderedChildren().get(2)).getTextContent(); + } + + @Override + public String getCity() { + return ((XSAny) getOrderedChildren().get(3)).getTextContent(); + } + + } + + public static class CustomSamlObjectBuilder extends AbstractXMLObjectBuilder { + + @Nonnull + @Override + public CustomOpenSamlObject buildObject(@Nullable String namespaceURI, @Nonnull String localName, + @Nullable String namespacePrefix) { + return new CustomOpenSamlObjectImpl(namespaceURI, localName, namespacePrefix); + } + + } + + public static class CustomSamlObjectMarshaller extends AbstractXMLObjectMarshaller { + + public CustomSamlObjectMarshaller() { + super(); + } + + @Override + protected void marshallElementContent(@Nonnull XMLObject xmlObject, @Nonnull Element domElement) { + final CustomOpenSamlObject customSamlObject = (CustomOpenSamlObject) xmlObject; + + for (XMLObject object : customSamlObject.getOrderedChildren()) { + ElementSupport.appendChildElement(domElement, object.getDOM()); + } + } + + } + + public static class CustomSamlObjectUnmarshaller extends AbstractXMLObjectUnmarshaller { + + public CustomSamlObjectUnmarshaller() { + super(); + } + + @Override + protected void processChildElement(@Nonnull XMLObject parentXMLObject, @Nonnull XMLObject childXMLObject) + throws UnmarshallingException { + final CustomOpenSamlObject customSamlObject = (CustomOpenSamlObject) parentXMLObject; + customSamlObject.getUnknownXMLObjects().add(childXMLObject); + } + + @Nonnull + @Override + protected XMLObject buildXMLObject(@Nonnull Element domElement) { + return new CustomOpenSamlObjectImpl(SAMLConstants.SAML20_NS, AttributeValue.DEFAULT_ELEMENT_LOCAL_NAME, + CustomOpenSamlObject.TYPE_CUSTOM_PREFIX); + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutRequestValidatorTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutRequestValidatorTests.java new file mode 100644 index 0000000000..43ceb70fad --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutRequestValidatorTests.java @@ -0,0 +1,223 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication.logout; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.saml.saml2.core.LogoutRequest; + +import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.authentication.DefaultSaml2AuthenticatedPrincipal; +import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; +import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; +import org.springframework.security.saml2.provider.service.authentication.logout.OpenSamlOperations.SignatureConfigurer; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OpenSaml5LogoutRequestValidator} + * + * @author Josh Cummings + */ +public class OpenSaml5LogoutRequestValidatorTests { + + private final OpenSamlOperations saml = new OpenSaml5Template(); + + private final OpenSaml5LogoutRequestValidator validator = new OpenSaml5LogoutRequestValidator(); + + @Test + public void handleWhenPostBindingThenValidates() { + RelyingPartyRegistration registration = registration().build(); + LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); + sign(logoutRequest, registration); + Saml2LogoutRequest request = post(logoutRequest, registration); + Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(request, + registration, authentication(registration)); + Saml2LogoutValidatorResult result = this.validator.validate(parameters); + assertThat(result.hasErrors()).isFalse(); + } + + @Test + public void handleWhenNameIdIsEncryptedIdPostThenValidates() { + + RelyingPartyRegistration registration = decrypting(encrypting(registration())).build(); + LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequestNameIdInEncryptedId(registration); + sign(logoutRequest, registration); + Saml2LogoutRequest request = post(logoutRequest, registration); + Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(request, + registration, authentication(registration)); + Saml2LogoutValidatorResult result = this.validator.validate(parameters); + assertThat(result.hasErrors()).withFailMessage(() -> result.getErrors().toString()).isFalse(); + + } + + @Test + public void handleWhenRedirectBindingThenValidatesSignatureParameter() { + RelyingPartyRegistration registration = registration() + .assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.REDIRECT)) + .build(); + LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); + Saml2LogoutRequest request = redirect(logoutRequest, registration, + this.saml.withSigningKeys(registration.getSigningX509Credentials())); + Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(request, + registration, authentication(registration)); + Saml2LogoutValidatorResult result = this.validator.validate(parameters); + assertThat(result.hasErrors()).isFalse(); + } + + @Test + public void handleWhenInvalidIssuerThenInvalidSignatureError() { + RelyingPartyRegistration registration = registration().build(); + LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); + logoutRequest.getIssuer().setValue("wrong"); + sign(logoutRequest, registration); + Saml2LogoutRequest request = post(logoutRequest, registration); + Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(request, + registration, authentication(registration)); + Saml2LogoutValidatorResult result = this.validator.validate(parameters); + assertThat(result.hasErrors()).isTrue(); + assertThat(result.getErrors().iterator().next().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_SIGNATURE); + } + + @Test + public void handleWhenMismatchedUserThenInvalidRequestError() { + RelyingPartyRegistration registration = registration().build(); + LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); + logoutRequest.getNameID().setValue("wrong"); + sign(logoutRequest, registration); + Saml2LogoutRequest request = post(logoutRequest, registration); + Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(request, + registration, authentication(registration)); + Saml2LogoutValidatorResult result = this.validator.validate(parameters); + assertThat(result.hasErrors()).isTrue(); + assertThat(result.getErrors().iterator().next().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_REQUEST); + } + + @Test + public void handleWhenMissingUserThenSubjectNotFoundError() { + RelyingPartyRegistration registration = registration().build(); + LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); + logoutRequest.setNameID(null); + sign(logoutRequest, registration); + Saml2LogoutRequest request = post(logoutRequest, registration); + Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(request, + registration, authentication(registration)); + Saml2LogoutValidatorResult result = this.validator.validate(parameters); + assertThat(result.hasErrors()).isTrue(); + assertThat(result.getErrors().iterator().next().getErrorCode()).isEqualTo(Saml2ErrorCodes.SUBJECT_NOT_FOUND); + } + + @Test + public void handleWhenMismatchedDestinationThenInvalidDestinationError() { + RelyingPartyRegistration registration = registration().build(); + LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); + logoutRequest.setDestination("wrong"); + sign(logoutRequest, registration); + Saml2LogoutRequest request = post(logoutRequest, registration); + Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(request, + registration, authentication(registration)); + Saml2LogoutValidatorResult result = this.validator.validate(parameters); + assertThat(result.hasErrors()).isTrue(); + assertThat(result.getErrors().iterator().next().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_DESTINATION); + } + + // gh-10923 + @Test + public void handleWhenLogoutResponseHasLineBreaksThenHandles() { + RelyingPartyRegistration registration = registration().build(); + LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); + sign(logoutRequest, registration); + String encoded = new StringBuffer( + Saml2Utils.samlEncode(serialize(logoutRequest).getBytes(StandardCharsets.UTF_8))) + .insert(10, "\r\n") + .toString(); + Saml2LogoutRequest request = Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .samlRequest(encoded) + .build(); + Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(request, + registration, authentication(registration)); + Saml2LogoutValidatorResult result = this.validator.validate(parameters); + assertThat(result.hasErrors()).isFalse(); + } + + private RelyingPartyRegistration.Builder registration() { + return signing(verifying(TestRelyingPartyRegistrations.noCredentials())) + .assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)); + } + + private RelyingPartyRegistration.Builder decrypting(RelyingPartyRegistration.Builder builder) { + return builder + .decryptionX509Credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyDecryptingCredential())); + } + + private RelyingPartyRegistration.Builder encrypting(RelyingPartyRegistration.Builder builder) { + return builder.assertingPartyDetails((party) -> party + .encryptionX509Credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartyEncryptingCredential()))); + } + + private RelyingPartyRegistration.Builder verifying(RelyingPartyRegistration.Builder builder) { + return builder.assertingPartyDetails((party) -> party + .verificationX509Credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()))); + } + + private RelyingPartyRegistration.Builder signing(RelyingPartyRegistration.Builder builder) { + return builder.signingX509Credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartySigningCredential())); + } + + private Authentication authentication(RelyingPartyRegistration registration) { + DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("user", new HashMap<>()); + principal.setRelyingPartyRegistrationId(registration.getRegistrationId()); + return new Saml2Authentication(principal, "response", new ArrayList<>()); + } + + private Saml2LogoutRequest post(LogoutRequest logoutRequest, RelyingPartyRegistration registration) { + return Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .samlRequest(Saml2Utils.samlEncode(serialize(logoutRequest).getBytes(StandardCharsets.UTF_8))) + .build(); + } + + private Saml2LogoutRequest redirect(LogoutRequest logoutRequest, RelyingPartyRegistration registration, + SignatureConfigurer configurer) { + String serialized = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(serialize(logoutRequest))); + Map parameters = configurer.sign(Map.of(Saml2ParameterNames.SAML_REQUEST, serialized)); + return Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .samlRequest(serialized) + .parameters((params) -> params.putAll(parameters)) + .build(); + } + + private void sign(LogoutRequest logoutRequest, RelyingPartyRegistration registration) { + TestOpenSamlObjects.signed(logoutRequest, registration.getSigningX509Credentials().iterator().next(), + registration.getAssertingPartyDetails().getEntityId()); + } + + private String serialize(XMLObject object) { + return this.saml.serialize(object).serialize(); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutResponseValidatorTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutResponseValidatorTests.java new file mode 100644 index 0000000000..08f2eeafb4 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSaml5LogoutResponseValidatorTests.java @@ -0,0 +1,190 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication.logout; + +import java.nio.charset.StandardCharsets; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.saml.saml2.core.LogoutResponse; +import org.opensaml.saml.saml2.core.StatusCode; + +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; +import org.springframework.security.saml2.provider.service.authentication.logout.OpenSamlOperations.SignatureConfigurer; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OpenSaml5LogoutResponseValidator} + * + * @author Josh Cummings + */ +public class OpenSaml5LogoutResponseValidatorTests { + + private final OpenSamlOperations saml = new OpenSaml5Template(); + + private final OpenSaml5LogoutResponseValidator manager = new OpenSaml5LogoutResponseValidator(); + + @Test + public void handleWhenAuthenticatedThenHandles() { + RelyingPartyRegistration registration = signing(verifying(registration())).build(); + Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .id("id") + .build(); + LogoutResponse logoutResponse = TestOpenSamlObjects.assertingPartyLogoutResponse(registration); + sign(logoutResponse, registration); + Saml2LogoutResponse response = post(logoutResponse, registration); + Saml2LogoutResponseValidatorParameters parameters = new Saml2LogoutResponseValidatorParameters(response, + logoutRequest, registration); + this.manager.validate(parameters); + } + + @Test + public void handleWhenRedirectBindingThenValidatesSignatureParameter() { + RelyingPartyRegistration registration = signing(verifying(registration())) + .assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.REDIRECT)) + .build(); + Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .id("id") + .build(); + LogoutResponse logoutResponse = TestOpenSamlObjects.assertingPartyLogoutResponse(registration); + Saml2LogoutResponse response = redirect(logoutResponse, registration, + this.saml.withSigningKeys(registration.getSigningX509Credentials())); + Saml2LogoutResponseValidatorParameters parameters = new Saml2LogoutResponseValidatorParameters(response, + logoutRequest, registration); + this.manager.validate(parameters); + } + + @Test + public void handleWhenInvalidIssuerThenInvalidSignatureError() { + RelyingPartyRegistration registration = registration().build(); + Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .id("id") + .build(); + LogoutResponse logoutResponse = TestOpenSamlObjects.assertingPartyLogoutResponse(registration); + logoutResponse.getIssuer().setValue("wrong"); + sign(logoutResponse, registration); + Saml2LogoutResponse response = post(logoutResponse, registration); + Saml2LogoutResponseValidatorParameters parameters = new Saml2LogoutResponseValidatorParameters(response, + logoutRequest, registration); + Saml2LogoutValidatorResult result = this.manager.validate(parameters); + assertThat(result.hasErrors()).isTrue(); + assertThat(result.getErrors().iterator().next().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_SIGNATURE); + } + + @Test + public void handleWhenMismatchedDestinationThenInvalidDestinationError() { + RelyingPartyRegistration registration = registration().build(); + Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .id("id") + .build(); + LogoutResponse logoutResponse = TestOpenSamlObjects.assertingPartyLogoutResponse(registration); + logoutResponse.setDestination("wrong"); + sign(logoutResponse, registration); + Saml2LogoutResponse response = post(logoutResponse, registration); + Saml2LogoutResponseValidatorParameters parameters = new Saml2LogoutResponseValidatorParameters(response, + logoutRequest, registration); + Saml2LogoutValidatorResult result = this.manager.validate(parameters); + assertThat(result.hasErrors()).isTrue(); + assertThat(result.getErrors().iterator().next().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_DESTINATION); + } + + @Test + public void handleWhenStatusNotSuccessThenInvalidResponseError() { + RelyingPartyRegistration registration = registration().build(); + Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .id("id") + .build(); + LogoutResponse logoutResponse = TestOpenSamlObjects.assertingPartyLogoutResponse(registration); + logoutResponse.getStatus().getStatusCode().setValue(StatusCode.UNKNOWN_PRINCIPAL); + sign(logoutResponse, registration); + Saml2LogoutResponse response = post(logoutResponse, registration); + Saml2LogoutResponseValidatorParameters parameters = new Saml2LogoutResponseValidatorParameters(response, + logoutRequest, registration); + Saml2LogoutValidatorResult result = this.manager.validate(parameters); + assertThat(result.hasErrors()).isTrue(); + assertThat(result.getErrors().iterator().next().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE); + } + + // gh-10923 + @Test + public void handleWhenLogoutResponseHasLineBreaksThenHandles() { + RelyingPartyRegistration registration = signing(verifying(registration())).build(); + Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .id("id") + .build(); + LogoutResponse logoutResponse = TestOpenSamlObjects.assertingPartyLogoutResponse(registration); + sign(logoutResponse, registration); + String encoded = new StringBuilder( + Saml2Utils.samlEncode(serialize(logoutResponse).getBytes(StandardCharsets.UTF_8))) + .insert(10, "\r\n") + .toString(); + Saml2LogoutResponse response = Saml2LogoutResponse.withRelyingPartyRegistration(registration) + .samlResponse(encoded) + .build(); + Saml2LogoutResponseValidatorParameters parameters = new Saml2LogoutResponseValidatorParameters(response, + logoutRequest, registration); + this.manager.validate(parameters); + } + + private RelyingPartyRegistration.Builder registration() { + return signing(verifying(TestRelyingPartyRegistrations.noCredentials())) + .assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)); + } + + private RelyingPartyRegistration.Builder verifying(RelyingPartyRegistration.Builder builder) { + return builder.assertingPartyDetails((party) -> party + .verificationX509Credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()))); + } + + private RelyingPartyRegistration.Builder signing(RelyingPartyRegistration.Builder builder) { + return builder.signingX509Credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartySigningCredential())); + } + + private Saml2LogoutResponse post(LogoutResponse logoutResponse, RelyingPartyRegistration registration) { + return Saml2LogoutResponse.withRelyingPartyRegistration(registration) + .samlResponse(Saml2Utils.samlEncode(serialize(logoutResponse).getBytes(StandardCharsets.UTF_8))) + .build(); + } + + private Saml2LogoutResponse redirect(LogoutResponse logoutResponse, RelyingPartyRegistration registration, + SignatureConfigurer configurer) { + String serialized = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(serialize(logoutResponse))); + Map parameters = configurer.sign(Map.of(Saml2ParameterNames.SAML_RESPONSE, serialized)); + return Saml2LogoutResponse.withRelyingPartyRegistration(registration) + .samlResponse(serialized) + .parameters((params) -> params.putAll(parameters)) + .build(); + } + + private void sign(LogoutResponse logoutResponse, RelyingPartyRegistration registration) { + TestOpenSamlObjects.signed(logoutResponse, registration.getSigningX509Credentials().iterator().next(), + registration.getAssertingPartyDetails().getEntityId()); + } + + private String serialize(XMLObject object) { + return this.saml.serialize(object).serialize(); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5MetadataResolverTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5MetadataResolverTests.java new file mode 100644 index 0000000000..c582b79876 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/metadata/OpenSaml5MetadataResolverTests.java @@ -0,0 +1,185 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.metadata; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.security.saml2.core.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OpenSaml5MetadataResolver} + */ +public class OpenSaml5MetadataResolverTests { + + @Test + public void resolveWhenRelyingPartyThenMetadataMatches() { + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full() + .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT) + .build(); + OpenSaml5MetadataResolver OpenSaml4MetadataResolver = new OpenSaml5MetadataResolver(); + String metadata = OpenSaml4MetadataResolver.resolve(relyingPartyRegistration); + assertThat(metadata).contains("") + .contains("") + .contains("MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh") + .contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"") + .contains("Location=\"https://rp.example.org/acs\" index=\"1\"") + .contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\""); + } + + @Test + public void resolveWhenRelyingPartyAndSignMetadataSetThenMetadataMatches() { + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full() + .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT) + .build(); + OpenSaml5MetadataResolver OpenSaml4MetadataResolver = new OpenSaml5MetadataResolver(); + OpenSaml4MetadataResolver.setSignMetadata(true); + String metadata = OpenSaml4MetadataResolver.resolve(relyingPartyRegistration); + assertThat(metadata).contains("") + .contains("") + .contains("MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh") + .contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"") + .contains("Location=\"https://rp.example.org/acs\" index=\"1\"") + .contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\"") + .contains("Signature xmlns:ds=\"http://www.w3.org/2000/09/xmldsig#\"") + .contains("CanonicalizationMethod Algorithm=\"http://www.w3.org/2001/10/xml-exc-c14n#") + .contains("SignatureMethod Algorithm=\"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256") + .contains("Reference URI=\"\"") + .contains("Transform Algorithm=\"http://www.w3.org/2000/09/xmldsig#enveloped-signature") + .contains("Transform Algorithm=\"http://www.w3.org/2001/10/xml-exc-c14n#\"") + .contains("DigestMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#sha256\"") + .contains("DigestValue") + .contains("SignatureValue"); + } + + @Test + public void resolveWhenRelyingPartyNoCredentialsThenMetadataMatches() { + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.noCredentials() + .assertingPartyDetails((party) -> party + .verificationX509Credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()))) + .build(); + OpenSaml5MetadataResolver OpenSaml4MetadataResolver = new OpenSaml5MetadataResolver(); + String metadata = OpenSaml4MetadataResolver.resolve(relyingPartyRegistration); + assertThat(metadata).contains("") + .doesNotContain("") + .contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"") + .contains("Location=\"https://rp.example.org/acs\" index=\"1\"") + .contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\""); + } + + @Test + public void resolveWhenRelyingPartyNameIDFormatThenMetadataMatches() { + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full() + .nameIdFormat("format") + .build(); + OpenSaml5MetadataResolver OpenSaml4MetadataResolver = new OpenSaml5MetadataResolver(); + String metadata = OpenSaml4MetadataResolver.resolve(relyingPartyRegistration); + assertThat(metadata).contains("format"); + } + + @Test + public void resolveWhenRelyingPartyNoLogoutThenMetadataMatches() { + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full() + .singleLogoutServiceLocation(null) + .nameIdFormat("format") + .build(); + OpenSaml5MetadataResolver OpenSaml4MetadataResolver = new OpenSaml5MetadataResolver(); + String metadata = OpenSaml4MetadataResolver.resolve(relyingPartyRegistration); + assertThat(metadata).doesNotContain("ResponseLocation"); + } + + @Test + public void resolveWhenEntityDescriptorCustomizerThenUses() { + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full() + .entityId("originalEntityId") + .build(); + OpenSaml5MetadataResolver OpenSaml4MetadataResolver = new OpenSaml5MetadataResolver(); + OpenSaml4MetadataResolver.setEntityDescriptorCustomizer( + (parameters) -> parameters.getEntityDescriptor().setEntityID("overriddenEntityId")); + String metadata = OpenSaml4MetadataResolver.resolve(relyingPartyRegistration); + assertThat(metadata).contains("") + .contains("") + .contains("MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh") + .contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"") + .contains("Location=\"https://rp.example.org/acs\" index=\"1\"") + .contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\""); + } + + @Test + public void resolveIterableWhenRelyingPartiesAndSignMetadataSetThenMetadataMatches() { + RelyingPartyRegistration one = TestRelyingPartyRegistrations.full() + .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT) + .build(); + RelyingPartyRegistration two = TestRelyingPartyRegistrations.full() + .entityId("two") + .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT) + .build(); + OpenSaml5MetadataResolver OpenSaml5MetadataResolver = new OpenSaml5MetadataResolver(); + OpenSaml5MetadataResolver.setSignMetadata(true); + String metadata = OpenSaml5MetadataResolver.resolve(List.of(one, two)); + assertThat(metadata).contains("") + .contains("") + .contains("MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh") + .contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"") + .contains("Location=\"https://rp.example.org/acs\" index=\"1\"") + .contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\"") + .contains("Signature xmlns:ds=\"http://www.w3.org/2000/09/xmldsig#\"") + .contains("CanonicalizationMethod Algorithm=\"http://www.w3.org/2001/10/xml-exc-c14n#") + .contains("SignatureMethod Algorithm=\"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256") + .contains("Reference URI=\"\"") + .contains("Transform Algorithm=\"http://www.w3.org/2000/09/xmldsig#enveloped-signature") + .contains("Transform Algorithm=\"http://www.w3.org/2001/10/xml-exc-c14n#\"") + .contains("DigestMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#sha256\"") + .contains("DigestValue") + .contains("SignatureValue"); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java new file mode 100644 index 0000000000..6184ed1db3 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java @@ -0,0 +1,379 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import net.shibboleth.shared.xml.SerializeSupport; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.saml.metadata.IterableMetadataSource; +import org.opensaml.saml.metadata.resolver.MetadataResolver; +import org.opensaml.saml.metadata.resolver.impl.FilesystemMetadataResolver; +import org.opensaml.saml.metadata.resolver.index.impl.RoleMetadataIndex; +import org.opensaml.saml.saml2.metadata.EntityDescriptor; +import org.opensaml.security.credential.Credential; +import org.w3c.dom.Element; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.ResourceLoader; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.OpenSamlInitializationService; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.withSettings; + +/** + * Tests for {@link BaseOpenSamlAssertingPartyMetadataRepository} + */ +public class OpenSaml5AssertingPartyMetadataRepositoryTests { + + static { + OpenSamlInitializationService.initialize(); + } + + private String metadata; + + private String entitiesDescriptor; + + @BeforeEach + public void setup() throws Exception { + ClassPathResource resource = new ClassPathResource("test-metadata.xml"); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { + this.metadata = reader.lines().collect(Collectors.joining()); + } + resource = new ClassPathResource("test-entitiesdescriptor.xml"); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { + this.entitiesDescriptor = reader.lines().collect(Collectors.joining()); + } + } + + @Test + public void withMetadataUrlLocationWhenResolvableThenFindByEntityIdReturns() throws Exception { + try (MockWebServer server = new MockWebServer()) { + server.setDispatcher(new AlwaysDispatch(this.metadata)); + AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation(server.url("/").toString()) + .build(); + AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth"); + assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); + assertThat(party.getSingleSignOnServiceLocation()) + .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); + assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(party.getVerificationX509Credentials()).hasSize(1); + assertThat(party.getEncryptionX509Credentials()).hasSize(1); + } + } + + @Test + public void withMetadataUrlLocationnWhenResolvableThenIteratorReturns() throws Exception { + try (MockWebServer server = new MockWebServer()) { + server.setDispatcher(new AlwaysDispatch(this.entitiesDescriptor)); + List parties = new ArrayList<>(); + OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(server.url("/").toString()) + .build() + .iterator() + .forEachRemaining(parties::add); + assertThat(parties).hasSize(2); + assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) + .contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth"); + } + } + + @Test + public void withMetadataUrlLocationWhenUnresolvableThenThrowsSaml2Exception() throws Exception { + try (MockWebServer server = new MockWebServer()) { + server.enqueue(new MockResponse().setBody(this.metadata).setResponseCode(200)); + String url = server.url("/").toString(); + server.shutdown(); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build()); + } + } + + @Test + public void withMetadataUrlLocationWhenMalformedResponseThenSaml2Exception() throws Exception { + try (MockWebServer server = new MockWebServer()) { + server.setDispatcher(new AlwaysDispatch("malformed")); + String url = server.url("/").toString(); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build()); + } + } + + @Test + public void fromMetadataFileLocationWhenResolvableThenFindByEntityIdReturns() { + File file = new File("src/test/resources/test-metadata.xml"); + AssertingPartyMetadata party = OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation("file:" + file.getAbsolutePath()) + .build() + .findByEntityId("https://idp.example.com/idp/shibboleth"); + assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); + assertThat(party.getSingleSignOnServiceLocation()) + .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); + assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(party.getVerificationX509Credentials()).hasSize(1); + assertThat(party.getEncryptionX509Credentials()).hasSize(1); + } + + @Test + public void fromMetadataFileLocationWhenResolvableThenIteratorReturns() { + File file = new File("src/test/resources/test-entitiesdescriptor.xml"); + Collection parties = new ArrayList<>(); + OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation("file:" + file.getAbsolutePath()) + .build() + .iterator() + .forEachRemaining(parties::add); + assertThat(parties).hasSize(2); + assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) + .contains("https://idp.example.com/idp/shibboleth", "https://ap.example.org/idp/shibboleth"); + } + + @Test + public void withMetadataFileLocationWhenNotFoundThenSaml2Exception() { + assertThatExceptionOfType(Saml2Exception.class).isThrownBy( + () -> OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation("file:path").build()); + } + + @Test + public void fromMetadataClasspathLocationWhenResolvableThenFindByEntityIdReturns() { + AssertingPartyMetadata party = OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation("classpath:test-entitiesdescriptor.xml") + .build() + .findByEntityId("https://ap.example.org/idp/shibboleth"); + assertThat(party.getEntityId()).isEqualTo("https://ap.example.org/idp/shibboleth"); + assertThat(party.getSingleSignOnServiceLocation()) + .isEqualTo("https://ap.example.org/idp/profile/SAML2/POST/SSO"); + assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(party.getVerificationX509Credentials()).hasSize(1); + assertThat(party.getEncryptionX509Credentials()).hasSize(1); + } + + @Test + public void fromMetadataClasspathLocationWhenResolvableThenIteratorReturns() { + Collection parties = new ArrayList<>(); + OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation("classpath:test-entitiesdescriptor.xml") + .build() + .iterator() + .forEachRemaining(parties::add); + assertThat(parties).hasSize(2); + assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) + .contains("https://idp.example.com/idp/shibboleth", "https://ap.example.org/idp/shibboleth"); + } + + @Test + public void withMetadataClasspathLocationWhenNotFoundThenSaml2Exception() { + assertThatExceptionOfType(Saml2Exception.class).isThrownBy( + () -> OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation("classpath:path").build()); + } + + @Test + public void withTrustedMetadataLocationWhenMatchingCredentialsThenVerifiesSignature() throws IOException { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); + EntityDescriptor descriptor = TestOpenSamlObjects.entityDescriptor(registration); + TestOpenSamlObjects.signed(descriptor, TestSaml2X509Credentials.assertingPartySigningCredential(), + descriptor.getEntityID()); + String serialized = serialize(descriptor); + Credential credential = TestOpenSamlObjects + .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); + try (MockWebServer server = new MockWebServer()) { + server.start(); + server.setDispatcher(new AlwaysDispatch(serialized)); + AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation(server.url("/").toString()) + .verificationCredentials((c) -> c.add(credential)) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); + } + } + + @Test + public void withTrustedMetadataLocationWhenMismatchingCredentialsThenSaml2Exception() throws IOException { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); + EntityDescriptor descriptor = TestOpenSamlObjects.entityDescriptor(registration); + TestOpenSamlObjects.signed(descriptor, TestSaml2X509Credentials.relyingPartySigningCredential(), + descriptor.getEntityID()); + String serialized = serialize(descriptor); + Credential credential = TestOpenSamlObjects + .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); + try (MockWebServer server = new MockWebServer()) { + server.start(); + server.setDispatcher(new AlwaysDispatch(serialized)); + assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation(server.url("/").toString()) + .verificationCredentials((c) -> c.add(credential)) + .build()); + } + } + + @Test + public void withTrustedMetadataLocationWhenNoCredentialsThenSkipsVerifySignature() throws IOException { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); + EntityDescriptor descriptor = TestOpenSamlObjects.entityDescriptor(registration); + TestOpenSamlObjects.signed(descriptor, TestSaml2X509Credentials.assertingPartySigningCredential(), + descriptor.getEntityID()); + String serialized = serialize(descriptor); + try (MockWebServer server = new MockWebServer()) { + server.start(); + server.setDispatcher(new AlwaysDispatch(serialized)); + AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation(server.url("/").toString()) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); + } + } + + @Test + public void withTrustedMetadataLocationWhenCustomResourceLoaderThenUses() { + ResourceLoader resourceLoader = mock(ResourceLoader.class); + given(resourceLoader.getResource(any())).willReturn(new ClassPathResource("test-metadata.xml")); + AssertingPartyMetadata party = OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation("classpath:wrong") + .resourceLoader(resourceLoader) + .build() + .iterator() + .next(); + assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); + assertThat(party.getSingleSignOnServiceLocation()) + .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); + assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(party.getVerificationX509Credentials()).hasSize(1); + assertThat(party.getEncryptionX509Credentials()).hasSize(1); + verify(resourceLoader).getResource(any()); + } + + @Test + public void constructorWhenNoIndexAndNoIteratorThenException() { + MetadataResolver resolver = mock(MetadataResolver.class); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> new OpenSaml5AssertingPartyMetadataRepository(resolver)); + } + + @Test + public void constructorWhenIterableResolverThenUses() { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); + EntityDescriptor descriptor = TestOpenSamlObjects.entityDescriptor(registration); + MetadataResolver resolver = mock(MetadataResolver.class, + withSettings().extraInterfaces(IterableMetadataSource.class)); + given(((IterableMetadataSource) resolver).iterator()).willReturn(List.of(descriptor).iterator()); + AssertingPartyMetadataRepository parties = new OpenSaml5AssertingPartyMetadataRepository(resolver); + parties.iterator() + .forEachRemaining((p) -> assertThat(p.getEntityId()) + .isEqualTo(registration.getAssertingPartyDetails().getEntityId())); + verify(((IterableMetadataSource) resolver)).iterator(); + } + + @Test + public void constructorWhenIndexedResolverThenUses() throws Exception { + FilesystemMetadataResolver resolver = new FilesystemMetadataResolver( + new ClassPathResource("test-metadata.xml").getFile()); + resolver.setIndexes(Set.of(new RoleMetadataIndex())); + resolver.setId("id"); + resolver.setParserPool(XMLObjectProviderRegistrySupport.getParserPool()); + resolver.initialize(); + MetadataResolver spied = spy(resolver); + AssertingPartyMetadataRepository parties = new OpenSaml5AssertingPartyMetadataRepository(spied); + parties.iterator() + .forEachRemaining((p) -> assertThat(p.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth")); + verify(spied).resolve(any()); + } + + @Test + public void withMetadataLocationWhenNoCredentialsThenException() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy( + () -> OpenSaml5AssertingPartyMetadataRepository.withMetadataLocation("classpath:test-metadata.xml") + .build()); + } + + @Test + public void withMetadataLocationWhenMatchingCredentialsThenVerifiesSignature() throws IOException { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); + EntityDescriptor descriptor = TestOpenSamlObjects.entityDescriptor(registration); + TestOpenSamlObjects.signed(descriptor, TestSaml2X509Credentials.assertingPartySigningCredential(), + descriptor.getEntityID()); + String serialized = serialize(descriptor); + Credential credential = TestOpenSamlObjects + .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); + try (MockWebServer server = new MockWebServer()) { + server.start(); + server.setDispatcher(new AlwaysDispatch(serialized)); + AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository + .withMetadataLocation(server.url("/").toString()) + .verificationCredentials((c) -> c.add(credential)) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); + } + } + + private static String serialize(XMLObject object) { + try { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + Element element = marshaller.marshall(object); + return SerializeSupport.nodeToString(element); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + private static final class AlwaysDispatch extends Dispatcher { + + private final MockResponse response; + + private AlwaysDispatch(String body) { + this.response = new MockResponse().setBody(body) + .setResponseCode(200) + .setBodyDelay(1, TimeUnit.MILLISECONDS); + } + + private AlwaysDispatch(MockResponse response) { + this.response = response; + } + + @Override + public MockResponse dispatch(RecordedRequest recordedRequest) throws InterruptedException { + return this.response; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java new file mode 100644 index 0000000000..1c35ec58e3 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java @@ -0,0 +1,246 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Instant; + +import jakarta.servlet.http.HttpServletRequest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.saml.common.SignableSAMLObject; +import org.opensaml.saml.saml2.core.Response; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2Utils; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; +import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.util.StreamUtils; +import org.springframework.web.util.UriUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link OpenSaml5AuthenticationTokenConverter} + */ +@ExtendWith(MockitoExtension.class) +public final class OpenSaml5AuthenticationTokenConverterTests { + + @Mock + RelyingPartyRegistrationRepository registrations; + + private final OpenSamlOperations saml = new OpenSaml5Template(); + + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); + + @Test + public void convertWhenSamlResponseThenToken() { + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, + Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8))); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.registration.getRegistrationId()); + } + + @Test + public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() { + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "invalid"); + assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request)) + .withCauseInstanceOf(IllegalArgumentException.class) + .satisfies( + (ex) -> assertThat(ex.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE)) + .satisfies( + (ex) -> assertThat(ex.getSaml2Error().getDescription()).isEqualTo("Failed to decode SAMLResponse")); + } + + @Test + public void convertWhenNoSamlResponseThenNull() { + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + assertThat(converter.convert(request)).isNull(); + } + + @Test + public void convertWhenNoMatchingRequestThenNull() { + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "ignored"); + assertThat(converter.convert(request)).isNull(); + } + + @Test + public void convertWhenNoRelyingPartyRegistrationThenNull() { + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + String response = Saml2Utils.samlEncode(serialize(signed(response())).getBytes(StandardCharsets.UTF_8)); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, response); + assertThat(converter.convert(request)).isNull(); + } + + @Test + public void convertWhenGetRequestThenInflates() { + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + MockHttpServletRequest request = get("/login/saml2/sso/" + this.registration.getRegistrationId()); + byte[] deflated = Saml2Utils.samlDeflate("response"); + String encoded = Saml2Utils.samlEncode(deflated); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, encoded); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.registration.getRegistrationId()); + } + + @Test + public void convertWhenGetRequestInvalidDeflatedThenSaml2AuthenticationException() { + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + MockHttpServletRequest request = get("/login/saml2/sso/" + this.registration.getRegistrationId()); + byte[] invalidDeflated = "invalid".getBytes(); + String encoded = Saml2Utils.samlEncode(invalidDeflated); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, encoded); + assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request)) + .withRootCauseInstanceOf(IOException.class) + .satisfies( + (ex) -> assertThat(ex.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE)) + .satisfies((ex) -> assertThat(ex.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string")); + } + + @Test + public void convertWhenUsingSamlUtilsBase64ThenXmlIsValid() throws Exception { + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, getSsoCircleEncodedXml()); + Saml2AuthenticationToken token = converter.convert(request); + validateSsoCircleXml(token.getSaml2Response()); + } + + @Test + public void convertWhenSavedAuthenticationRequestThenToken() { + Saml2AuthenticationRequestRepository authenticationRequestRepository = mock( + Saml2AuthenticationRequestRepository.class); + AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + given(authenticationRequest.getRelyingPartyRegistrationId()).willReturn(this.registration.getRegistrationId()); + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + converter.setAuthenticationRequestRepository(authenticationRequestRepository); + given(this.registrations.findByRegistrationId(any())).willReturn(this.registration); + given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class))) + .willReturn(authenticationRequest); + MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId()); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, + Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8))); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.registration.getRegistrationId()); + assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest); + } + + @Test + public void convertWhenMatchingNoRegistrationIdThenLooksUpByAssertingEntityId() { + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + String response = serialize(signed(response())); + String encoded = Saml2Utils.samlEncode(response.getBytes(StandardCharsets.UTF_8)); + given(this.registrations.findUniqueByAssertingPartyEntityId(TestOpenSamlObjects.ASSERTING_PARTY_ENTITY_ID)) + .willReturn(this.registration); + MockHttpServletRequest request = post("/login/saml2/sso"); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, encoded); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo(response); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.registration.getRegistrationId()); + } + + @Test + public void constructorWhenResolverIsNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null)); + } + + @Test + public void setAuthenticationRequestRepositoryWhenNullThenIllegalArgument() { + OpenSaml5AuthenticationTokenConverter converter = new OpenSaml5AuthenticationTokenConverter(this.registrations); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> converter.setAuthenticationRequestRepository(null)); + } + + private void validateSsoCircleXml(String xml) { + assertThat(xml).contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"") + .contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"") + .contains("https://idp.ssocircle.com"); + } + + private String getSsoCircleEncodedXml() throws IOException { + ClassPathResource resource = new ClassPathResource("saml2-response-sso-circle.encoded"); + String response = StreamUtils.copyToString(resource.getInputStream(), StandardCharsets.UTF_8); + return UriUtils.decode(response, StandardCharsets.UTF_8); + } + + private MockHttpServletRequest post(String uri) { + MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); + request.setServletPath(uri); + return request; + } + + private MockHttpServletRequest get(String uri) { + MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); + request.setServletPath(uri); + return request; + } + + private T signed(T toSign) { + TestOpenSamlObjects.signed(toSign, TestSaml2X509Credentials.assertingPartySigningCredential(), + TestOpenSamlObjects.RELYING_PARTY_ENTITY_ID); + return toSign; + } + + private Response response() { + Response response = TestOpenSamlObjects.response(); + response.setIssueInstant(Instant.now()); + return response; + } + + private String serialize(XMLObject object) { + return this.saml.serialize(object).serialize(); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java new file mode 100644 index 0000000000..2284dfa0fd --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import jakarta.servlet.http.HttpServletRequest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +public class OpenSaml5AuthenticationRequestResolverTests { + + MockHttpServletRequest request; + + RelyingPartyRegistration registration; + + @BeforeEach + void setup() { + this.request = givenRequest("/saml2/authenticate/registration-id"); + this.registration = TestRelyingPartyRegistrations.full().build(); + } + + @Test + void resolveWhenRedirectThenSaml2RedirectAuthenticationRequest() { + RelyingPartyRegistrationResolver relyingParties = mock(RelyingPartyRegistrationResolver.class); + given(relyingParties.resolve(any(), any())).willReturn(this.registration); + OpenSaml5AuthenticationRequestResolver resolver = new OpenSaml5AuthenticationRequestResolver(relyingParties); + Saml2RedirectAuthenticationRequest authnRequest = resolver.resolve(this.request); + assertThat(authnRequest.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + assertThat(authnRequest.getAuthenticationRequestUri()) + .isEqualTo(this.registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); + } + + @Test + void resolveWhenPostThenSaml2PostAuthenticationRequest() { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() + .assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)) + .build(); + RelyingPartyRegistrationResolver relyingParties = mock(RelyingPartyRegistrationResolver.class); + given(relyingParties.resolve(any(), any())).willReturn(registration); + OpenSaml5AuthenticationRequestResolver resolver = new OpenSaml5AuthenticationRequestResolver(relyingParties); + Saml2PostAuthenticationRequest authnRequest = resolver.resolve(this.request); + assertThat(authnRequest.getBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(authnRequest.getAuthenticationRequestUri()) + .isEqualTo(this.registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); + } + + @Test + void resolveWhenCustomRelayStateThenUses() { + RelyingPartyRegistrationResolver relyingParties = mock(RelyingPartyRegistrationResolver.class); + given(relyingParties.resolve(any(), any())).willReturn(this.registration); + Converter relayState = mock(Converter.class); + given(relayState.convert(any())).willReturn("state"); + OpenSaml5AuthenticationRequestResolver resolver = new OpenSaml5AuthenticationRequestResolver(relyingParties); + resolver.setRelayStateResolver(relayState); + Saml2RedirectAuthenticationRequest authnRequest = resolver.resolve(this.request); + assertThat(authnRequest.getRelayState()).isEqualTo("state"); + verify(relayState).convert(any()); + } + + @Test + void resolveWhenCustomAuthenticationUrlTHenUses() { + RelyingPartyRegistrationResolver relyingParties = mock(RelyingPartyRegistrationResolver.class); + given(relyingParties.resolve(any(), any())).willReturn(this.registration); + OpenSaml5AuthenticationRequestResolver resolver = new OpenSaml5AuthenticationRequestResolver(relyingParties); + resolver.setRequestMatcher(new AntPathRequestMatcher("/custom/authentication/{registrationId}")); + Saml2RedirectAuthenticationRequest authnRequest = resolver + .resolve(givenRequest("/custom/authentication/registration-id")); + + assertThat(authnRequest.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + assertThat(authnRequest.getAuthenticationRequestUri()) + .isEqualTo(this.registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); + + } + + private MockHttpServletRequest givenRequest(String path) { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setServletPath(path); + return request; + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5SigningUtilsTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5SigningUtilsTests.java new file mode 100644 index 0000000000..2870fdc7c2 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5SigningUtilsTests.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication; + +import java.util.UUID; + +import javax.xml.namespace.QName; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.saml.common.SAMLVersion; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.xmlsec.signature.Signature; + +import org.springframework.security.saml2.core.OpenSamlInitializationService; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test open SAML signatures + */ +public class OpenSaml5SigningUtilsTests { + + static { + OpenSamlInitializationService.initialize(); + } + + private final OpenSamlOperations saml = new OpenSaml5Template(); + + private RelyingPartyRegistration registration; + + @BeforeEach + public void setup() { + this.registration = RelyingPartyRegistration.withRegistrationId("saml-idp") + .entityId("https://some.idp.example.com/entity-id") + .signingX509Credentials((c) -> { + c.add(TestSaml2X509Credentials.relyingPartySigningCredential()); + c.add(TestSaml2X509Credentials.assertingPartySigningCredential()); + }) + .assertingPartyDetails((c) -> c.entityId("https://some.idp.example.com/entity-id") + .singleSignOnServiceLocation("https://some.idp.example.com/service-location")) + .build(); + } + + @Test + public void whenSigningAnObjectThenKeyInfoIsPartOfTheSignature() { + Response response = response("destination", "issuer"); + this.saml.withSigningKeys(this.registration.getSigningX509Credentials()).sign(response); + Signature signature = response.getSignature(); + assertThat(signature).isNotNull(); + assertThat(signature.getKeyInfo()).isNotNull(); + } + + Response response(String destination, String issuerEntityId) { + Response response = build(Response.DEFAULT_ELEMENT_NAME); + response.setID("R" + UUID.randomUUID()); + response.setVersion(SAMLVersion.VERSION_20); + response.setID("_" + UUID.randomUUID()); + response.setDestination(destination); + response.setIssuer(issuer(issuerEntityId)); + return response; + } + + Issuer issuer(String entityId) { + Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME); + issuer.setValue(entityId); + return issuer; + } + + T build(QName qName) { + return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestResolverTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestResolverTests.java new file mode 100644 index 0000000000..3dcbfc8584 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestResolverTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import jakarta.servlet.http.HttpServletRequest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link OpenSaml5LogoutRequestResolver} + */ +public class OpenSaml5LogoutRequestResolverTests { + + RelyingPartyRegistration registration; + + RelyingPartyRegistrationResolver registrationResolver; + + OpenSaml5LogoutRequestResolver logoutRequestResolver; + + @BeforeEach + public void setup() { + this.registration = TestRelyingPartyRegistrations.full().build(); + this.registrationResolver = mock(RelyingPartyRegistrationResolver.class); + this.logoutRequestResolver = new OpenSaml5LogoutRequestResolver(this.registrationResolver); + } + + @Test + public void resolveWhenCustomParametersConsumerThenUses() { + this.logoutRequestResolver.setParametersConsumer((parameters) -> parameters.getLogoutRequest().setID("myid")); + given(this.registrationResolver.resolve(any(), any())).willReturn(this.registration); + + Saml2LogoutRequest logoutRequest = this.logoutRequestResolver.resolve(givenRequest(), givenAuthentication()); + + assertThat(logoutRequest.getId()).isEqualTo("myid"); + } + + @Test + public void setParametersConsumerWhenNullThenIllegalArgument() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.logoutRequestResolver.setParametersConsumer(null)); + } + + @Test + public void resolveWhenCustomRelayStateThenUses() { + given(this.registrationResolver.resolve(any(), any())).willReturn(this.registration); + Converter relayState = mock(Converter.class); + given(relayState.convert(any())).willReturn("any-state"); + this.logoutRequestResolver.setRelayStateResolver(relayState); + + Saml2LogoutRequest logoutRequest = this.logoutRequestResolver.resolve(givenRequest(), givenAuthentication()); + + assertThat(logoutRequest.getRelayState()).isEqualTo("any-state"); + verify(relayState).convert(any()); + } + + private static Authentication givenAuthentication() { + return new TestingAuthenticationToken("user", "password"); + } + + private MockHttpServletRequest givenRequest() { + return new MockHttpServletRequest(); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java new file mode 100644 index 0000000000..8ec0cef306 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java @@ -0,0 +1,153 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import java.nio.charset.StandardCharsets; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensaml.core.xml.XMLObject; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; +import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; +import org.springframework.security.saml2.provider.service.authentication.TestSaml2Authentications; +import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.given; + +@ExtendWith(MockitoExtension.class) +public final class OpenSaml5LogoutRequestValidatorParametersResolverTests { + + @Mock + RelyingPartyRegistrationRepository registrations; + + private final OpenSamlOperations saml = new OpenSaml5Template(); + + private RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); + + private OpenSaml5LogoutRequestValidatorParametersResolver resolver; + + @BeforeEach + void setup() { + this.resolver = new OpenSaml5LogoutRequestValidatorParametersResolver(this.registrations); + } + + @Test + void saml2LogoutRegistrationIdResolveWhenMatchesThenParameters() { + String registrationId = this.registration.getRegistrationId(); + MockHttpServletRequest request = post("/logout/saml2/slo/" + registrationId); + Authentication authentication = new TestingAuthenticationToken("user", "pass"); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + given(this.registrations.findByRegistrationId(registrationId)).willReturn(this.registration); + Saml2LogoutRequestValidatorParameters parameters = this.resolver.resolve(request, authentication); + assertThat(parameters.getAuthentication()).isEqualTo(authentication); + assertThat(parameters.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(registrationId); + assertThat(parameters.getLogoutRequest().getSamlRequest()).isEqualTo("request"); + } + + @Test + void saml2LogoutRegistrationIdWhenUnauthenticatedThenParameters() { + String registrationId = this.registration.getRegistrationId(); + MockHttpServletRequest request = post("/logout/saml2/slo/" + registrationId); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + given(this.registrations.findByRegistrationId(registrationId)).willReturn(this.registration); + Saml2LogoutRequestValidatorParameters parameters = this.resolver.resolve(request, null); + assertThat(parameters.getAuthentication()).isNull(); + assertThat(parameters.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(registrationId); + assertThat(parameters.getLogoutRequest().getSamlRequest()).isEqualTo("request"); + } + + @Test + void saml2LogoutResolveWhenAuthenticatedThenParameters() { + String registrationId = this.registration.getRegistrationId(); + MockHttpServletRequest request = post("/logout/saml2/slo"); + Authentication authentication = TestSaml2Authentications.authentication(); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + given(this.registrations.findByRegistrationId(registrationId)).willReturn(this.registration); + Saml2LogoutRequestValidatorParameters parameters = this.resolver.resolve(request, authentication); + assertThat(parameters.getAuthentication()).isEqualTo(authentication); + assertThat(parameters.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(registrationId); + assertThat(parameters.getLogoutRequest().getSamlRequest()).isEqualTo("request"); + } + + @Test + void saml2LogoutResolveWhenUnauthenticatedThenParameters() { + String registrationId = this.registration.getRegistrationId(); + MockHttpServletRequest request = post("/logout/saml2/slo"); + String logoutRequest = serialize(TestOpenSamlObjects.logoutRequest()); + String encoded = Saml2Utils.samlEncode(logoutRequest.getBytes(StandardCharsets.UTF_8)); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, encoded); + given(this.registrations.findUniqueByAssertingPartyEntityId(TestOpenSamlObjects.ASSERTING_PARTY_ENTITY_ID)) + .willReturn(this.registration); + Saml2LogoutRequestValidatorParameters parameters = this.resolver.resolve(request, null); + assertThat(parameters.getAuthentication()).isNull(); + assertThat(parameters.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(registrationId); + assertThat(parameters.getLogoutRequest().getSamlRequest()).isEqualTo(encoded); + } + + @Test + void saml2LogoutResolveWhenUnauthenticatedGetRequestThenInflates() { + String registrationId = this.registration.getRegistrationId(); + MockHttpServletRequest request = get("/logout/saml2/slo"); + String logoutRequest = serialize(TestOpenSamlObjects.logoutRequest()); + String encoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(logoutRequest)); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, encoded); + given(this.registrations.findUniqueByAssertingPartyEntityId(TestOpenSamlObjects.ASSERTING_PARTY_ENTITY_ID)) + .willReturn(this.registration); + Saml2LogoutRequestValidatorParameters parameters = this.resolver.resolve(request, null); + assertThat(parameters.getAuthentication()).isNull(); + assertThat(parameters.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(registrationId); + assertThat(parameters.getLogoutRequest().getSamlRequest()).isEqualTo(encoded); + } + + @Test + void saml2LogoutRegistrationIdResolveWhenNoMatchingRegistrationIdThenSaml2Exception() { + MockHttpServletRequest request = post("/logout/saml2/slo/id"); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.resolver.resolve(request, null)); + } + + private MockHttpServletRequest post(String uri) { + MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); + request.setServletPath(uri); + return request; + } + + private MockHttpServletRequest get(String uri) { + MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); + request.setServletPath(uri); + return request; + } + + private String serialize(XMLObject object) { + return this.saml.serialize(object).serialize(); + } + +} diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutResponseResolverTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutResponseResolverTests.java new file mode 100644 index 0000000000..392b5ef817 --- /dev/null +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutResponseResolverTests.java @@ -0,0 +1,80 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.opensaml.saml.saml2.core.LogoutRequest; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; +import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutResponse; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml5LogoutResponseResolver.LogoutResponseParameters; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link OpenSaml5LogoutResponseResolver} + */ +public class OpenSaml5LogoutResponseResolverTests { + + private final OpenSamlOperations saml = new OpenSaml5Template(); + + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = mock(RelyingPartyRegistrationResolver.class); + + @Test + public void resolveWhenCustomParametersConsumerThenUses() { + OpenSaml5LogoutResponseResolver logoutResponseResolver = new OpenSaml5LogoutResponseResolver( + this.relyingPartyRegistrationResolver); + Consumer parametersConsumer = mock(Consumer.class); + logoutResponseResolver.setParametersConsumer(parametersConsumer); + MockHttpServletRequest request = new MockHttpServletRequest(); + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration() + .assertingPartyDetails( + (party) -> party.singleLogoutServiceResponseLocation("https://ap.example.com/logout")) + .build(); + Authentication authentication = new TestingAuthenticationToken("user", "password"); + LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, + Saml2Utils.samlEncode(this.saml.serialize(logoutRequest).serialize().getBytes())); + given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); + Saml2LogoutResponse logoutResponse = logoutResponseResolver.resolve(request, authentication); + assertThat(logoutResponse).isNotNull(); + verify(parametersConsumer).accept(any()); + } + + @Test + public void setParametersConsumerWhenNullThenIllegalArgument() { + OpenSaml5LogoutRequestResolver logoutRequestResolver = new OpenSaml5LogoutRequestResolver( + this.relyingPartyRegistrationResolver); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> logoutRequestResolver.setParametersConsumer(null)); + } + +}