From 51fc05630d809143a2181f04e9e25eca1826d84c Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Mon, 5 Aug 2024 08:37:54 -0600 Subject: [PATCH] Use OpenSAML API for web.authentication.logout Issue gh-11658 --- .../saml2/Saml2LogoutConfigurer.java | 4 +- .../ROOT/pages/servlet/saml2/logout.adoc | 2 +- ...ing-security-saml2-service-provider.gradle | 6 + ...=> BaseOpenSamlLogoutRequestResolver.java} | 109 +++- ...outRequestValidatorParametersResolver.java | 199 ++++++ ...> BaseOpenSamlLogoutResponseResolver.java} | 154 +++-- .../OpenSaml4LogoutRequestResolver.java | 28 +- ...outRequestValidatorParametersResolver.java | 100 +++ .../OpenSaml4LogoutResponseResolver.java | 52 +- .../logout/OpenSaml4Template.java | 617 ++++++++++++++++++ ...outRequestValidatorParametersResolver.java | 48 +- .../logout/OpenSamlOperations.java | 184 ++++++ .../logout/OpenSamlSigningUtils.java | 194 ------ .../web/authentication/logout/Saml2Utils.java | 122 +++- ...questValidatorParametersResolverTests.java | 153 +++++ .../OpenSaml4LogoutResponseResolverTests.java | 4 +- .../OpenSamlLogoutRequestResolverTests.java | 121 ---- ...questValidatorParametersResolverTests.java | 17 +- .../OpenSamlLogoutResponseResolverTests.java | 153 ----- .../logout/OpenSamlSigningUtilsTests.java | 91 --- .../logout/Saml2LogoutSigningUtilsTests.java | 59 -- 21 files changed, 1613 insertions(+), 804 deletions(-) rename saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/{OpenSamlLogoutRequestResolver.java => BaseOpenSamlLogoutRequestResolver.java} (72%) create mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutRequestValidatorParametersResolver.java rename saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/{OpenSamlLogoutResponseResolver.java => BaseOpenSamlLogoutResponseResolver.java} (65%) create mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolver.java create mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4Template.java create mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlOperations.java delete mode 100644 saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlSigningUtils.java create mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java delete mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolverTests.java delete mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolverTests.java delete mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlSigningUtilsTests.java delete mode 100644 saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutSigningUtilsTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java index 4d1291fd18..d746732e5b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java @@ -40,7 +40,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.security.saml2.provider.service.web.authentication.logout.HttpSessionLogoutRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutRequestResolver; import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutResponseResolver; -import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSamlLogoutRequestValidatorParametersResolver; +import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutRequestValidatorParametersResolver; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestFilter; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestResolver; @@ -251,7 +251,7 @@ public final class Saml2LogoutConfigurer> LogoutHandler[] logoutHandlers = this.logoutHandlers.toArray(new LogoutHandler[0]); Saml2LogoutResponseResolver logoutResponseResolver = createSaml2LogoutResponseResolver(registrations); RequestMatcher requestMatcher = createLogoutRequestMatcher(); - OpenSamlLogoutRequestValidatorParametersResolver parameters = new OpenSamlLogoutRequestValidatorParametersResolver( + OpenSaml4LogoutRequestValidatorParametersResolver parameters = new OpenSaml4LogoutRequestValidatorParametersResolver( registrations); parameters.setRequestMatcher(requestMatcher); Saml2LogoutRequestFilter filter = new Saml2LogoutRequestFilter(parameters, diff --git a/docs/modules/ROOT/pages/servlet/saml2/logout.adoc b/docs/modules/ROOT/pages/servlet/saml2/logout.adoc index 91aeb651ae..03e7bb153e 100644 --- a/docs/modules/ROOT/pages/servlet/saml2/logout.adoc +++ b/docs/modules/ROOT/pages/servlet/saml2/logout.adoc @@ -605,7 +605,7 @@ Kotlin:: ---- @Component open class MyOpenSamlLogoutResponseValidator: Saml2LogoutResponseValidator { - private val delegate = OpenSamlLogoutResponseValidator() + private val delegate = OpenSaml4LogoutResponseValidator() @Override fun logout(parameters: Saml2LogoutResponseValidatorParameters): Saml2LogoutResponseValidator { diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index 3b1a60b309..a2576d297b 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -25,6 +25,12 @@ sourceSets.configureEach { set -> filter { line -> line.replaceAll(".saml2.internal", ".saml2.provider.service.metadata") } with from } + + copy { + into "$projectDir/src/$set.name/java/org/springframework/security/saml2/provider/service/web/authentication/logout" + filter { line -> line.replaceAll(".saml2.internal", ".saml2.provider.service.web.authentication.logout") } + with from + } } dependencies { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutRequestResolver.java similarity index 72% rename from saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolver.java rename to saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutRequestResolver.java index 937233146e..980cfa4ac6 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutRequestResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,18 @@ package org.springframework.security.saml2.provider.service.web.authentication.logout; -import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; import java.util.UUID; -import java.util.function.BiConsumer; +import java.util.function.Consumer; import jakarta.servlet.http.HttpServletRequest; -import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.xml.config.XMLObjectProviderRegistry; -import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.saml.saml2.core.Issuer; import org.opensaml.saml.saml2.core.LogoutRequest; import org.opensaml.saml.saml2.core.NameID; @@ -36,11 +37,9 @@ import org.opensaml.saml.saml2.core.impl.LogoutRequestBuilder; import org.opensaml.saml.saml2.core.impl.LogoutRequestMarshaller; import org.opensaml.saml.saml2.core.impl.NameIDBuilder; import org.opensaml.saml.saml2.core.impl.SessionIndexBuilder; -import org.w3c.dom.Element; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.Authentication; -import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2ParameterNames; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal; @@ -50,14 +49,13 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; -import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSamlSigningUtils.QueryParametersPartial; import org.springframework.util.Assert; /** * For internal use only. Intended for consolidating common behavior related to minting a * SAML 2.0 Logout Request. */ -final class OpenSamlLogoutRequestResolver { +final class BaseOpenSamlLogoutRequestResolver implements Saml2LogoutRequestResolver { static { OpenSamlInitializationService.initialize(); @@ -65,6 +63,10 @@ final class OpenSamlLogoutRequestResolver { private final Log logger = LogFactory.getLog(getClass()); + private final OpenSamlOperations saml; + + private Clock clock = Clock.systemUTC(); + private final LogoutRequestMarshaller marshaller; private final IssuerBuilder issuerBuilder; @@ -79,11 +81,16 @@ final class OpenSamlLogoutRequestResolver { private Converter relayStateResolver = (request) -> UUID.randomUUID().toString(); + private Consumer parametersConsumer = (parameters) -> { + }; + /** - * Construct a {@link OpenSamlLogoutRequestResolver} + * Construct a {@link BaseOpenSamlLogoutRequestResolver} */ - OpenSamlLogoutRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + BaseOpenSamlLogoutRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver, + OpenSamlOperations saml) { this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; + this.saml = saml; XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); this.marshaller = (LogoutRequestMarshaller) registry.getMarshallerFactory() .getMarshaller(LogoutRequest.DEFAULT_ELEMENT_NAME); @@ -100,10 +107,18 @@ final class OpenSamlLogoutRequestResolver { Assert.notNull(this.sessionIndexBuilder, "sessionIndexBuilder must be configured in OpenSAML"); } + void setClock(Clock clock) { + this.clock = clock; + } + void setRelayStateResolver(Converter relayStateResolver) { this.relayStateResolver = relayStateResolver; } + void setParametersConsumer(Consumer parametersConsumer) { + this.parametersConsumer = parametersConsumer; + } + /** * Prepare to create, sign, and serialize a SAML 2.0 Logout Request. * @@ -114,13 +129,8 @@ final class OpenSamlLogoutRequestResolver { * @param authentication the current user * @return a signed and serialized SAML 2.0 Logout Request */ - Saml2LogoutRequest resolve(HttpServletRequest request, Authentication authentication) { - return resolve(request, authentication, (registration, logoutRequest) -> { - }); - } - - Saml2LogoutRequest resolve(HttpServletRequest request, Authentication authentication, - BiConsumer logoutRequestConsumer) { + @Override + public Saml2LogoutRequest resolve(HttpServletRequest request, Authentication authentication) { String registrationId = getRegistrationId(authentication); RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request, registrationId); if (registration == null) { @@ -147,7 +157,9 @@ final class OpenSamlLogoutRequestResolver { logoutRequest.getSessionIndexes().add(sessionIndex); } } - logoutRequestConsumer.accept(registration, logoutRequest); + logoutRequest.setIssueInstant(Instant.now(this.clock)); + this.parametersConsumer + .accept(new LogoutRequestParameters(request, registration, authentication, logoutRequest)); if (logoutRequest.getID() == null) { logoutRequest.setID("LR" + UUID.randomUUID()); } @@ -155,18 +167,23 @@ final class OpenSamlLogoutRequestResolver { Saml2LogoutRequest.Builder result = Saml2LogoutRequest.withRelyingPartyRegistration(registration) .id(logoutRequest.getID()); if (registration.getAssertingPartyMetadata().getSingleLogoutServiceBinding() == Saml2MessageBinding.POST) { - String xml = serialize(OpenSamlSigningUtils.sign(logoutRequest, registration)); - String samlRequest = Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8)); + String xml = serialize(this.saml.withSigningKeys(registration.getSigningX509Credentials()) + .algorithms(registration.getAssertingPartyMetadata().getSigningAlgorithms()) + .sign(logoutRequest)); + String samlRequest = Saml2Utils.withDecoded(xml).encode(); return result.samlRequest(samlRequest).relayState(relayState).build(); } else { String xml = serialize(logoutRequest); - String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); + String deflatedAndEncoded = Saml2Utils.withDecoded(xml).deflate(true).encode(); result.samlRequest(deflatedAndEncoded); - QueryParametersPartial partial = OpenSamlSigningUtils.sign(registration) - .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded) - .param(Saml2ParameterNames.RELAY_STATE, relayState); - return result.parameters((params) -> params.putAll(partial.parameters())).build(); + Map signingParameters = new HashMap<>(); + signingParameters.put(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded); + signingParameters.put(Saml2ParameterNames.RELAY_STATE, relayState); + Map query = this.saml.withSigningKeys(registration.getSigningX509Credentials()) + .algorithms(registration.getAssertingPartyMetadata().getSigningAlgorithms()) + .sign(signingParameters); + return result.parameters((params) -> params.putAll(query)).build(); } } @@ -185,13 +202,43 @@ final class OpenSamlLogoutRequestResolver { } private String serialize(LogoutRequest logoutRequest) { - try { - Element element = this.marshaller.marshall(logoutRequest); - return SerializeSupport.nodeToString(element); + return this.saml.serialize(logoutRequest).serialize(); + } + + static final class LogoutRequestParameters { + + private final HttpServletRequest request; + + private final RelyingPartyRegistration registration; + + private final Authentication authentication; + + private final LogoutRequest logoutRequest; + + LogoutRequestParameters(HttpServletRequest request, RelyingPartyRegistration registration, + Authentication authentication, LogoutRequest logoutRequest) { + this.request = request; + this.registration = registration; + this.authentication = authentication; + this.logoutRequest = logoutRequest; + } + + HttpServletRequest getRequest() { + return this.request; } - catch (MarshallingException ex) { - throw new Saml2Exception(ex); + + RelyingPartyRegistration getRelyingPartyRegistration() { + return this.registration; + } + + Authentication getAuthentication() { + return this.authentication; } + + LogoutRequest getLogoutRequest() { + return this.logoutRequest; + } + } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutRequestValidatorParametersResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutRequestValidatorParametersResolver.java new file mode 100644 index 0000000000..95853fbe9f --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutRequestValidatorParametersResolver.java @@ -0,0 +1,199 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import jakarta.servlet.http.HttpServletRequest; +import org.opensaml.saml.saml2.core.LogoutRequest; + +import org.springframework.http.HttpMethod; +import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.core.OpenSamlInitializationService; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; +import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest; +import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.OrRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; + +/** + * An OpenSAML-based implementation of + * {@link Saml2LogoutRequestValidatorParametersResolver} + */ +final class BaseOpenSamlLogoutRequestValidatorParametersResolver + implements Saml2LogoutRequestValidatorParametersResolver { + + static { + OpenSamlInitializationService.initialize(); + } + + private final OpenSamlOperations saml; + + private final RelyingPartyRegistrationRepository registrations; + + private RequestMatcher requestMatcher = new OrRequestMatcher( + new AntPathRequestMatcher("/logout/saml2/slo/{registrationId}"), + new AntPathRequestMatcher("/logout/saml2/slo")); + + /** + * Constructs a {@link BaseOpenSamlLogoutRequestValidatorParametersResolver} + */ + BaseOpenSamlLogoutRequestValidatorParametersResolver(OpenSamlOperations saml, + RelyingPartyRegistrationRepository registrations) { + Assert.notNull(registrations, "relyingPartyRegistrationRepository cannot be null"); + this.saml = saml; + this.registrations = registrations; + } + + /** + * Construct the parameters necessary for validating an asserting party's + * {@code } based on the given {@link HttpServletRequest} + * + *

+ * Uses the configured {@link RequestMatcher} to identify the processing request, + * including looking for any indicated {@code registrationId}. + * + *

+ * If a {@code registrationId} is found in the request, it will attempt to use that, + * erroring if no {@link RelyingPartyRegistration} is found. + * + *

+ * If no {@code registrationId} is found in the request, it will look for a currently + * logged-in user and use the associated {@code registrationId}. + * + *

+ * In the event that neither the URL nor any logged in user could determine a + * {@code registrationId}, this code then will try and derive a + * {@link RelyingPartyRegistration} given the {@code }'s + * {@code Issuer} value. + * @param request the HTTP request + * @return a {@link Saml2LogoutRequestValidatorParameters} instance, or {@code null} + * if one could not be resolved + * @throws Saml2AuthenticationException if the {@link RequestMatcher} specifies a + * non-existent {@code registrationId} + */ + @Override + public Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request, Authentication authentication) { + if (request.getParameter(Saml2ParameterNames.SAML_REQUEST) == null) { + return null; + } + RequestMatcher.MatchResult result = this.requestMatcher.matcher(request); + if (!result.isMatch()) { + return null; + } + String registrationId = getRegistrationId(result, authentication); + if (registrationId == null) { + return logoutRequestByEntityId(request, authentication); + } + return logoutRequestById(request, authentication, registrationId); + } + + /** + * The request matcher to use to identify a request to process a + * {@code }. By default, checks for {@code /logout/saml2/slo} and + * {@code /logout/saml2/slo/{registrationId}}. + * + *

+ * Generally speaking, the URL does not need to have a {@code registrationId} in it + * since either it can be looked up from the active logged in user or it can be + * derived through the {@code Issuer} in the {@code }. + * @param requestMatcher the {@link RequestMatcher} to use + */ + void setRequestMatcher(RequestMatcher requestMatcher) { + Assert.notNull(requestMatcher, "requestMatcher cannot be null"); + this.requestMatcher = requestMatcher; + } + + private String getRegistrationId(RequestMatcher.MatchResult result, Authentication authentication) { + String registrationId = result.getVariables().get("registrationId"); + if (registrationId != null) { + return registrationId; + } + if (authentication == null) { + return null; + } + if (authentication.getPrincipal() instanceof Saml2AuthenticatedPrincipal principal) { + return principal.getRelyingPartyRegistrationId(); + } + return null; + } + + private Saml2LogoutRequestValidatorParameters logoutRequestById(HttpServletRequest request, + Authentication authentication, String registrationId) { + RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId); + if (registration == null) { + throw new Saml2AuthenticationException( + new Saml2Error(Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND, "registration not found"), + "registration not found"); + } + return logoutRequestByRegistration(request, registration, authentication); + } + + private Saml2LogoutRequestValidatorParameters logoutRequestByEntityId(HttpServletRequest request, + Authentication authentication) { + String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST); + LogoutRequest logoutRequest = this.saml.deserialize( + Saml2Utils.withEncoded(serialized).inflate(HttpMethod.GET.matches(request.getMethod())).decode()); + String issuer = logoutRequest.getIssuer().getValue(); + RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(issuer); + return logoutRequestByRegistration(request, registration, authentication); + } + + private Saml2LogoutRequestValidatorParameters logoutRequestByRegistration(HttpServletRequest request, + RelyingPartyRegistration registration, Authentication authentication) { + if (registration == null) { + return null; + } + Saml2MessageBinding saml2MessageBinding = Saml2MessageBindingUtils.resolveBinding(request); + registration = fromRequest(request, registration); + String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST); + Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .samlRequest(serialized) + .relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE)) + .binding(saml2MessageBinding) + .location(registration.getSingleLogoutServiceLocation()) + .parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG, + request.getParameter(Saml2ParameterNames.SIG_ALG))) + .parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE, + request.getParameter(Saml2ParameterNames.SIGNATURE))) + .parametersQuery((params) -> request.getQueryString()) + .build(); + return new Saml2LogoutRequestValidatorParameters(logoutRequest, registration, authentication); + } + + private RelyingPartyRegistration fromRequest(HttpServletRequest request, RelyingPartyRegistration registration) { + RelyingPartyRegistrationPlaceholderResolvers.UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers + .uriResolver(request, registration); + String entityId = uriResolver.resolve(registration.getEntityId()); + String logoutLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation()); + String logoutResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation()); + return registration.mutate() + .entityId(entityId) + .singleLogoutServiceLocation(logoutLocation) + .singleLogoutServiceResponseLocation(logoutResponseLocation) + .build(); + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutResponseResolver.java similarity index 65% rename from saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolver.java rename to saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutResponseResolver.java index 531bcad602..41a44a1856 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutResponseResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,20 +16,19 @@ package org.springframework.security.saml2.provider.service.web.authentication.logout; -import java.io.ByteArrayInputStream; -import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; import java.util.UUID; -import java.util.function.BiConsumer; +import java.util.function.Consumer; import jakarta.servlet.http.HttpServletRequest; -import net.shibboleth.utilities.java.support.xml.ParserPool; -import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.xml.config.XMLObjectProviderRegistry; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; -import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.saml.saml2.core.Issuer; import org.opensaml.saml.saml2.core.LogoutRequest; import org.opensaml.saml.saml2.core.LogoutResponse; @@ -41,11 +40,8 @@ import org.opensaml.saml.saml2.core.impl.LogoutResponseBuilder; import org.opensaml.saml.saml2.core.impl.LogoutResponseMarshaller; import org.opensaml.saml.saml2.core.impl.StatusBuilder; import org.opensaml.saml.saml2.core.impl.StatusCodeBuilder; -import org.w3c.dom.Document; -import org.w3c.dom.Element; import org.springframework.security.core.Authentication; -import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2ParameterNames; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal; @@ -56,14 +52,13 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; -import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSamlSigningUtils.QueryParametersPartial; import org.springframework.util.Assert; /** * For internal use only. Intended for consolidating common behavior related to minting a * SAML 2.0 Logout Response. */ -final class OpenSamlLogoutResponseResolver { +final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseResolver { static { OpenSamlInitializationService.initialize(); @@ -71,7 +66,7 @@ final class OpenSamlLogoutResponseResolver { private final Log logger = LogFactory.getLog(getClass()); - private final ParserPool parserPool; + private XMLObjectProviderRegistry registry; private final LogoutRequestUnmarshaller unmarshaller; @@ -85,32 +80,39 @@ final class OpenSamlLogoutResponseResolver { private final StatusCodeBuilder statusCodeBuilder; + private final OpenSamlOperations saml; + private final RelyingPartyRegistrationRepository registrations; private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; + private Clock clock = Clock.systemUTC(); + + private Consumer parametersConsumer = (parameters) -> { + }; + /** - * Construct a {@link OpenSamlLogoutResponseResolver} + * Construct a {@link BaseOpenSamlLogoutResponseResolver} */ - OpenSamlLogoutResponseResolver(RelyingPartyRegistrationRepository registrations, - RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + BaseOpenSamlLogoutResponseResolver(RelyingPartyRegistrationRepository registrations, + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver, OpenSamlOperations saml) { + this.saml = saml; this.registrations = registrations; this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; - XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); - this.parserPool = registry.getParserPool(); + this.registry = ConfigurationService.get(XMLObjectProviderRegistry.class); this.unmarshaller = (LogoutRequestUnmarshaller) XMLObjectProviderRegistrySupport.getUnmarshallerFactory() .getUnmarshaller(LogoutRequest.DEFAULT_ELEMENT_NAME); - this.marshaller = (LogoutResponseMarshaller) registry.getMarshallerFactory() + this.marshaller = (LogoutResponseMarshaller) this.registry.getMarshallerFactory() .getMarshaller(LogoutResponse.DEFAULT_ELEMENT_NAME); Assert.notNull(this.marshaller, "logoutResponseMarshaller must be configured in OpenSAML"); - this.logoutResponseBuilder = (LogoutResponseBuilder) registry.getBuilderFactory() + this.logoutResponseBuilder = (LogoutResponseBuilder) this.registry.getBuilderFactory() .getBuilder(LogoutResponse.DEFAULT_ELEMENT_NAME); Assert.notNull(this.logoutResponseBuilder, "logoutResponseBuilder must be configured in OpenSAML"); - this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME); + this.issuerBuilder = (IssuerBuilder) this.registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME); Assert.notNull(this.issuerBuilder, "issuerBuilder must be configured in OpenSAML"); - this.statusBuilder = (StatusBuilder) registry.getBuilderFactory().getBuilder(Status.DEFAULT_ELEMENT_NAME); + this.statusBuilder = (StatusBuilder) this.registry.getBuilderFactory().getBuilder(Status.DEFAULT_ELEMENT_NAME); Assert.notNull(this.statusBuilder, "statusBuilder must be configured in OpenSAML"); - this.statusCodeBuilder = (StatusCodeBuilder) registry.getBuilderFactory() + this.statusCodeBuilder = (StatusCodeBuilder) this.registry.getBuilderFactory() .getBuilder(StatusCode.DEFAULT_ELEMENT_NAME); Assert.notNull(this.statusCodeBuilder, "statusCodeBuilder must be configured in OpenSAML"); } @@ -126,14 +128,9 @@ final class OpenSamlLogoutResponseResolver { * @param authentication the current user * @return a signed and serialized SAML 2.0 Logout Response */ - Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication) { - return resolve(request, authentication, (registration, logoutResponse) -> { - }); - } - - Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication, - BiConsumer logoutResponseConsumer) { - LogoutRequest logoutRequest = parse(extractSamlRequest(request)); + @Override + public Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication) { + LogoutRequest logoutRequest = this.saml.deserialize(extractSamlRequest(request)); String registrationId = getRegistrationId(authentication); RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request, registrationId); if (registration == null && this.registrations != null) { @@ -163,30 +160,46 @@ final class OpenSamlLogoutResponseResolver { if (logoutResponse.getID() == null) { logoutResponse.setID("LR" + UUID.randomUUID()); } - logoutResponseConsumer.accept(registration, logoutResponse); + logoutResponse.setIssueInstant(Instant.now(this.clock)); + this.parametersConsumer + .accept(new LogoutResponseParameters(request, registration, authentication, logoutRequest)); + String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE); Saml2LogoutResponse.Builder result = Saml2LogoutResponse.withRelyingPartyRegistration(registration); if (registration.getAssertingPartyMetadata().getSingleLogoutServiceBinding() == Saml2MessageBinding.POST) { - String xml = serialize(OpenSamlSigningUtils.sign(logoutResponse, registration)); - String samlResponse = Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8)); + String xml = serialize(this.saml.withSigningKeys(registration.getSigningX509Credentials()) + .algorithms(registration.getAssertingPartyMetadata().getSigningAlgorithms()) + .sign(logoutResponse)); + String samlResponse = Saml2Utils.withDecoded(xml).encode(); result.samlResponse(samlResponse); - if (request.getParameter(Saml2ParameterNames.RELAY_STATE) != null) { - result.relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE)); + if (relayState != null) { + result.relayState(relayState); } return result.build(); } else { String xml = serialize(logoutResponse); - String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); + String deflatedAndEncoded = Saml2Utils.withDecoded(xml).deflate(true).encode(); result.samlResponse(deflatedAndEncoded); - QueryParametersPartial partial = OpenSamlSigningUtils.sign(registration) - .param(Saml2ParameterNames.SAML_RESPONSE, deflatedAndEncoded); - if (request.getParameter(Saml2ParameterNames.RELAY_STATE) != null) { - partial.param(Saml2ParameterNames.RELAY_STATE, request.getParameter(Saml2ParameterNames.RELAY_STATE)); + Map signingParameters = new HashMap<>(); + signingParameters.put(Saml2ParameterNames.SAML_RESPONSE, deflatedAndEncoded); + if (relayState != null) { + signingParameters.put(Saml2ParameterNames.RELAY_STATE, relayState); } - return result.parameters((params) -> params.putAll(partial.parameters())).build(); + Map parameters = this.saml.withSigningKeys(registration.getSigningX509Credentials()) + .algorithms(registration.getAssertingPartyMetadata().getSigningAlgorithms()) + .sign(signingParameters); + return result.parameters((params) -> params.putAll(parameters)).build(); } } + void setClock(Clock clock) { + this.clock = clock; + } + + void setParametersConsumer(Consumer parametersConsumer) { + this.parametersConsumer = parametersConsumer; + } + private String getRegistrationId(Authentication authentication) { if (this.logger.isTraceEnabled()) { this.logger.trace("Attempting to resolve registrationId from " + authentication); @@ -202,34 +215,49 @@ final class OpenSamlLogoutResponseResolver { } private String extractSamlRequest(HttpServletRequest request) { - String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST); - byte[] b = Saml2Utils.samlDecode(serialized); - if (Saml2MessageBindingUtils.isHttpRedirectBinding(request)) { - return Saml2Utils.samlInflate(b); - } - return new String(b, StandardCharsets.UTF_8); + return Saml2Utils.withEncoded(request.getParameter(Saml2ParameterNames.SAML_REQUEST)) + .inflate(Saml2MessageBindingUtils.isHttpRedirectBinding(request)) + .decode(); + } + + private String serialize(LogoutResponse logoutResponse) { + return this.saml.serialize(logoutResponse).serialize(); } - private LogoutRequest parse(String request) throws Saml2Exception { - try { - Document document = this.parserPool - .parse(new ByteArrayInputStream(request.getBytes(StandardCharsets.UTF_8))); - Element element = document.getDocumentElement(); - return (LogoutRequest) this.unmarshaller.unmarshall(element); + static final class LogoutResponseParameters { + + private final HttpServletRequest request; + + private final RelyingPartyRegistration registration; + + private final Authentication authentication; + + private final LogoutRequest logoutRequest; + + LogoutResponseParameters(HttpServletRequest request, RelyingPartyRegistration registration, + Authentication authentication, LogoutRequest logoutRequest) { + this.request = request; + this.registration = registration; + this.authentication = authentication; + this.logoutRequest = logoutRequest; + } + + HttpServletRequest getRequest() { + return this.request; } - catch (Exception ex) { - throw new Saml2Exception("Failed to deserialize LogoutRequest", ex); + + RelyingPartyRegistration getRelyingPartyRegistration() { + return this.registration; } - } - private String serialize(LogoutResponse logoutResponse) { - try { - Element element = this.marshaller.marshall(logoutResponse); - return SerializeSupport.nodeToString(element); + Authentication getAuthentication() { + return this.authentication; } - catch (MarshallingException ex) { - throw new Saml2Exception(ex); + + LogoutRequest getLogoutRequest() { + return this.logoutRequest; } + } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolver.java index 8180791f04..0c46afa635 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolver.java @@ -41,12 +41,7 @@ import org.springframework.util.Assert; */ public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestResolver { - private final OpenSamlLogoutRequestResolver logoutRequestResolver; - - private Consumer parametersConsumer = (parameters) -> { - }; - - private Clock clock = Clock.systemUTC(); + private final BaseOpenSamlLogoutRequestResolver delegate; public OpenSaml4LogoutRequestResolver(RelyingPartyRegistrationRepository registrations) { this((request, id) -> { @@ -61,7 +56,8 @@ public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestR * Construct a {@link OpenSaml4LogoutRequestResolver} */ public OpenSaml4LogoutRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { - this.logoutRequestResolver = new OpenSamlLogoutRequestResolver(relyingPartyRegistrationResolver); + this.delegate = new BaseOpenSamlLogoutRequestResolver(relyingPartyRegistrationResolver, + new OpenSaml4Template()); } /** @@ -69,11 +65,7 @@ public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestR */ @Override public Saml2LogoutRequest resolve(HttpServletRequest request, Authentication authentication) { - return this.logoutRequestResolver.resolve(request, authentication, (registration, logoutRequest) -> { - logoutRequest.setIssueInstant(Instant.now(this.clock)); - this.parametersConsumer - .accept(new LogoutRequestParameters(request, registration, authentication, logoutRequest)); - }); + return this.delegate.resolve(request, authentication); } /** @@ -83,7 +75,8 @@ public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestR */ public void setParametersConsumer(Consumer parametersConsumer) { Assert.notNull(parametersConsumer, "parametersConsumer cannot be null"); - this.parametersConsumer = parametersConsumer; + this.delegate + .setParametersConsumer((parameters) -> parametersConsumer.accept(new LogoutRequestParameters(parameters))); } /** @@ -92,7 +85,7 @@ public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestR */ public void setClock(Clock clock) { Assert.notNull(clock, "clock must not be null"); - this.clock = clock; + this.delegate.setClock(clock); } /** @@ -102,7 +95,7 @@ public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestR */ public void setRelayStateResolver(Converter relayStateResolver) { Assert.notNull(relayStateResolver, "relayStateResolver cannot be null"); - this.logoutRequestResolver.setRelayStateResolver(relayStateResolver); + this.delegate.setRelayStateResolver(relayStateResolver); } public static final class LogoutRequestParameters { @@ -123,6 +116,11 @@ public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestR this.logoutRequest = logoutRequest; } + LogoutRequestParameters(BaseOpenSamlLogoutRequestResolver.LogoutRequestParameters parameters) { + this(parameters.getRequest(), parameters.getRelyingPartyRegistration(), parameters.getAuthentication(), + parameters.getLogoutRequest()); + } + public HttpServletRequest getRequest() { return this.request; } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolver.java new file mode 100644 index 0000000000..a29785bef9 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolver.java @@ -0,0 +1,100 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.core.OpenSamlInitializationService; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; +import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; + +/** + * An OpenSAML-based implementation of + * {@link Saml2LogoutRequestValidatorParametersResolver} + */ +public final class OpenSaml4LogoutRequestValidatorParametersResolver + implements Saml2LogoutRequestValidatorParametersResolver { + + static { + OpenSamlInitializationService.initialize(); + } + + private final BaseOpenSamlLogoutRequestValidatorParametersResolver delegate; + + /** + * Constructs a {@link OpenSaml4LogoutRequestValidatorParametersResolver} + */ + public OpenSaml4LogoutRequestValidatorParametersResolver(RelyingPartyRegistrationRepository registrations) { + Assert.notNull(registrations, "relyingPartyRegistrationRepository cannot be null"); + this.delegate = new BaseOpenSamlLogoutRequestValidatorParametersResolver(new OpenSaml4Template(), + registrations); + } + + /** + * Construct the parameters necessary for validating an asserting party's + * {@code } based on the given {@link HttpServletRequest} + * + *

+ * Uses the configured {@link RequestMatcher} to identify the processing request, + * including looking for any indicated {@code registrationId}. + * + *

+ * If a {@code registrationId} is found in the request, it will attempt to use that, + * erroring if no {@link RelyingPartyRegistration} is found. + * + *

+ * If no {@code registrationId} is found in the request, it will look for a currently + * logged-in user and use the associated {@code registrationId}. + * + *

+ * In the event that neither the URL nor any logged in user could determine a + * {@code registrationId}, this code then will try and derive a + * {@link RelyingPartyRegistration} given the {@code }'s + * {@code Issuer} value. + * @param request the HTTP request + * @return a {@link Saml2LogoutRequestValidatorParameters} instance, or {@code null} + * if one could not be resolved + * @throws Saml2AuthenticationException if the {@link RequestMatcher} specifies a + * non-existent {@code registrationId} + */ + @Override + public Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request, Authentication authentication) { + return this.delegate.resolve(request, authentication); + } + + /** + * The request matcher to use to identify a request to process a + * {@code }. By default, checks for {@code /logout/saml2/slo} and + * {@code /logout/saml2/slo/{registrationId}}. + * + *

+ * Generally speaking, the URL does not need to have a {@code registrationId} in it + * since either it can be looked up from the active logged in user or it can be + * derived through the {@code Issuer} in the {@code }. + * @param requestMatcher the {@link RequestMatcher} to use + */ + public void setRequestMatcher(RequestMatcher requestMatcher) { + Assert.notNull(requestMatcher, "requestMatcher cannot be null"); + this.delegate.setRequestMatcher(requestMatcher); + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutResponseResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutResponseResolver.java index 6e95b3dae1..4d70062b3f 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutResponseResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutResponseResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import java.time.Instant; import java.util.function.Consumer; import jakarta.servlet.http.HttpServletRequest; -import org.opensaml.saml.saml2.core.LogoutResponse; +import org.opensaml.saml.saml2.core.LogoutRequest; import org.springframework.security.core.Authentication; import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutResponse; @@ -39,27 +39,23 @@ import org.springframework.util.Assert; */ public final class OpenSaml4LogoutResponseResolver implements Saml2LogoutResponseResolver { - private final OpenSamlLogoutResponseResolver logoutResponseResolver; - - private Consumer parametersConsumer = (parameters) -> { - }; - - private Clock clock = Clock.systemUTC(); + private final BaseOpenSamlLogoutResponseResolver delegate; public OpenSaml4LogoutResponseResolver(RelyingPartyRegistrationRepository registrations) { - this.logoutResponseResolver = new OpenSamlLogoutResponseResolver(registrations, (request, id) -> { + this.delegate = new BaseOpenSamlLogoutResponseResolver(registrations, (request, id) -> { if (id == null) { return null; } return registrations.findByRegistrationId(id); - }); + }, new OpenSaml4Template()); } /** * Construct a {@link OpenSaml4LogoutResponseResolver} */ public OpenSaml4LogoutResponseResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { - this.logoutResponseResolver = new OpenSamlLogoutResponseResolver(null, relyingPartyRegistrationResolver); + this.delegate = new BaseOpenSamlLogoutResponseResolver(null, relyingPartyRegistrationResolver, + new OpenSaml4Template()); } /** @@ -67,26 +63,27 @@ public final class OpenSaml4LogoutResponseResolver implements Saml2LogoutRespons */ @Override public Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication) { - return this.logoutResponseResolver.resolve(request, authentication, (registration, logoutResponse) -> { - logoutResponse.setIssueInstant(Instant.now(this.clock)); - this.parametersConsumer - .accept(new LogoutResponseParameters(request, registration, authentication, logoutResponse)); - }); + return this.delegate.resolve(request, authentication); } /** - * Set a {@link Consumer} for modifying the OpenSAML {@link LogoutResponse} + * Set a {@link Consumer} for modifying the OpenSAML {@link LogoutRequest} * @param parametersConsumer a consumer that accepts an - * {@link LogoutResponseParameters} + * {@link OpenSaml4LogoutRequestResolver.LogoutRequestParameters} */ public void setParametersConsumer(Consumer parametersConsumer) { Assert.notNull(parametersConsumer, "parametersConsumer cannot be null"); - this.parametersConsumer = parametersConsumer; + this.delegate + .setParametersConsumer((parameters) -> parametersConsumer.accept(new LogoutResponseParameters(parameters))); } + /** + * Use this {@link Clock} for determining the issued {@link Instant} + * @param clock the {@link Clock} to use + */ public void setClock(Clock clock) { Assert.notNull(clock, "clock must not be null"); - this.clock = clock; + this.delegate.setClock(clock); } public static final class LogoutResponseParameters { @@ -97,14 +94,19 @@ public final class OpenSaml4LogoutResponseResolver implements Saml2LogoutRespons private final Authentication authentication; - private final LogoutResponse logoutResponse; + private final LogoutRequest logoutRequest; public LogoutResponseParameters(HttpServletRequest request, RelyingPartyRegistration registration, - Authentication authentication, LogoutResponse logoutResponse) { + Authentication authentication, LogoutRequest logoutRequest) { this.request = request; this.registration = registration; this.authentication = authentication; - this.logoutResponse = logoutResponse; + this.logoutRequest = logoutRequest; + } + + LogoutResponseParameters(BaseOpenSamlLogoutResponseResolver.LogoutResponseParameters parameters) { + this(parameters.getRequest(), parameters.getRelyingPartyRegistration(), parameters.getAuthentication(), + parameters.getLogoutRequest()); } public HttpServletRequest getRequest() { @@ -119,8 +121,8 @@ public final class OpenSaml4LogoutResponseResolver implements Saml2LogoutRespons return this.authentication; } - public LogoutResponse getLogoutResponse() { - return this.logoutResponse; + public LogoutRequest getLogoutRequest() { + return this.logoutRequest; } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4Template.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4Template.java new file mode 100644 index 0000000000..eee4fc8242 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4Template.java @@ -0,0 +1,617 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.xml.namespace.QName; + +import net.shibboleth.utilities.java.support.resolver.CriteriaSet; +import net.shibboleth.utilities.java.support.xml.SerializeSupport; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.opensaml.core.criterion.EntityIdCriterion; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.Marshaller; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.io.Unmarshaller; +import org.opensaml.core.xml.io.UnmarshallerFactory; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.criterion.ProtocolCriterion; +import org.opensaml.saml.ext.saml2delrestrict.Delegate; +import org.opensaml.saml.ext.saml2delrestrict.DelegationRestrictionType; +import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Condition; +import org.opensaml.saml.saml2.core.EncryptedAssertion; +import org.opensaml.saml.saml2.core.EncryptedAttribute; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.LogoutRequest; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.encryption.Decrypter; +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialResolver; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion; +import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion; +import org.opensaml.security.credential.impl.CollectionCredentialResolver; +import org.opensaml.security.criteria.UsageCriterion; +import org.opensaml.security.x509.BasicX509Credential; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.DecryptionException; +import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver; +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; +import org.opensaml.xmlsec.keyinfo.impl.CollectionKeyInfoCredentialResolver; +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.opensaml.xmlsec.signature.support.SignatureTrustEngine; +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * For internal use only. Subject to breaking changes at any time. + */ +final class OpenSaml4Template implements OpenSamlOperations { + + private static final Log logger = LogFactory.getLog(OpenSaml4Template.class); + + @Override + public T build(QName elementName) { + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + elementName); + } + return (T) builder.buildObject(elementName); + } + + @Override + public T deserialize(String serialized) { + return deserialize(new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8))); + } + + @Override + public T deserialize(InputStream serialized) { + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool().parse(serialized); + Element element = document.getDocumentElement(); + UnmarshallerFactory factory = XMLObjectProviderRegistrySupport.getUnmarshallerFactory(); + Unmarshaller unmarshaller = factory.getUnmarshaller(element); + if (unmarshaller == null) { + throw new Saml2Exception("Unsupported element of type " + element.getTagName()); + } + return (T) unmarshaller.unmarshall(element); + } + catch (Saml2Exception ex) { + throw ex; + } + catch (Exception ex) { + throw new Saml2Exception("Failed to deserialize payload", ex); + } + } + + @Override + public OpenSaml4SerializationConfigurer serialize(XMLObject object) { + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); + try { + return serialize(marshaller.marshall(object)); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + @Override + public OpenSaml4SerializationConfigurer serialize(Element element) { + return new OpenSaml4SerializationConfigurer(element); + } + + @Override + public OpenSaml4SignatureConfigurer withSigningKeys(Collection credentials) { + return new OpenSaml4SignatureConfigurer(credentials); + } + + @Override + public OpenSaml4VerificationConfigurer withVerificationKeys(Collection credentials) { + return new OpenSaml4VerificationConfigurer(credentials); + } + + @Override + public OpenSaml4DecryptionConfigurer withDecryptionKeys(Collection credentials) { + return new OpenSaml4DecryptionConfigurer(credentials); + } + + OpenSaml4Template() { + + } + + static final class OpenSaml4SerializationConfigurer + implements SerializationConfigurer { + + private final Element element; + + boolean pretty; + + OpenSaml4SerializationConfigurer(Element element) { + this.element = element; + } + + @Override + public OpenSaml4SerializationConfigurer prettyPrint(boolean pretty) { + this.pretty = pretty; + return this; + } + + @Override + public String serialize() { + if (this.pretty) { + return SerializeSupport.prettyPrintXML(this.element); + } + return SerializeSupport.nodeToString(this.element); + } + + } + + static final class OpenSaml4SignatureConfigurer implements SignatureConfigurer { + + private final Collection credentials; + + private final Map components = new LinkedHashMap<>(); + + private List algs = List.of(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + + OpenSaml4SignatureConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public OpenSaml4SignatureConfigurer algorithms(List algs) { + this.algs = algs; + return this; + } + + @Override + public O sign(O object) { + SignatureSigningParameters parameters = resolveSigningParameters(); + try { + SignatureSupport.signObject(object, parameters); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + return object; + } + + @Override + public Map sign(Map params) { + SignatureSigningParameters parameters = resolveSigningParameters(); + this.components.putAll(params); + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : this.components.entrySet()) { + builder.queryParam(component.getKey(), + UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + return this.components; + } + + private SignatureSigningParameters resolveSigningParameters() { + List credentials = resolveSigningCredentials(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(this.algs); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); + CriteriaSet criteria = new CriteriaSet(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { + final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); + + namedManager.setUseDefaultManager(true); + final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); + + // Generator for X509Credentials + final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); + x509Factory.setEmitEntityCertificate(true); + x509Factory.setEmitEntityCertificateChain(true); + + defaultManager.registerFactory(x509Factory); + + return namedManager; + } + + private List resolveSigningCredentials() { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : this.credentials) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + + } + + static final class OpenSaml4VerificationConfigurer implements VerificationConfigurer { + + private final Collection credentials; + + private String entityId; + + OpenSaml4VerificationConfigurer(Collection credentials) { + this.credentials = credentials; + } + + @Override + public VerificationConfigurer entityId(String entityId) { + this.entityId = entityId; + return this; + } + + private SignatureTrustEngine trustEngine(Collection keys) { + Set credentials = new HashSet<>(); + for (Saml2X509Credential key : keys) { + BasicX509Credential cred = new BasicX509Credential(key.getCertificate()); + cred.setUsageType(UsageType.SIGNING); + cred.setEntityId(this.entityId); + credentials.add(cred); + } + CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials); + return new ExplicitKeySignatureTrustEngine(credentialsResolver, + DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()); + } + + private CriteriaSet verificationCriteria(Issuer issuer) { + return new CriteriaSet(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer.getValue())), + new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)), + new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING))); + } + + @Override + public Collection verify(SignableXMLObject signable) { + if (signable instanceof StatusResponseType response) { + return verifySignature(response.getID(), response.getIssuer(), response.getSignature()); + } + if (signable instanceof RequestAbstractType request) { + return verifySignature(request.getID(), request.getIssuer(), request.getSignature()); + } + if (signable instanceof Assertion assertion) { + return verifySignature(assertion.getID(), assertion.getIssuer(), assertion.getSignature()); + } + throw new Saml2Exception("Unsupported object of type: " + signable.getClass().getName()); + } + + private Collection verifySignature(String id, Issuer issuer, Signature signature) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(issuer); + Collection errors = new ArrayList<>(); + SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); + try { + profileValidator.validate(signature); + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + try { + if (!trustEngine.validate(signature, criteria)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + id + "]: ")); + } + + return errors; + } + + @Override + public Collection verify(RedirectParameters parameters) { + SignatureTrustEngine trustEngine = trustEngine(this.credentials); + CriteriaSet criteria = verificationCriteria(parameters.getIssuer()); + if (parameters.getAlgorithm() == null) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature algorithm for object [" + parameters.getId() + "]")); + } + if (!parameters.hasSignature()) { + return Collections.singletonList(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Missing signature for object [" + parameters.getId() + "]")); + } + Collection errors = new ArrayList<>(); + String algorithmUri = parameters.getAlgorithm(); + try { + if (!trustEngine.validate(parameters.getSignature(), parameters.getContent(), algorithmUri, criteria, + null)) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]")); + } + } + catch (Exception ex) { + errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, + "Invalid signature for object [" + parameters.getId() + "]: ")); + } + return errors; + } + + } + + static final class OpenSaml4DecryptionConfigurer implements DecryptionConfigurer { + + private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver( + Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(), + new SimpleRetrievalMethodEncryptedKeyResolver())); + + private final Decrypter decrypter; + + OpenSaml4DecryptionConfigurer(Collection decryptionCredentials) { + this.decrypter = decrypter(decryptionCredentials); + } + + private static Decrypter decrypter(Collection decryptionCredentials) { + Collection credentials = new ArrayList<>(); + for (Saml2X509Credential key : decryptionCredentials) { + Credential cred = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey()); + credentials.add(cred); + } + KeyInfoCredentialResolver resolver = new CollectionKeyInfoCredentialResolver(credentials); + Decrypter decrypter = new Decrypter(null, resolver, encryptedKeyResolver); + decrypter.setRootInNewDocument(true); + return decrypter; + } + + @Override + public void decrypt(XMLObject object) { + if (object instanceof Response response) { + decryptResponse(response); + return; + } + if (object instanceof Assertion assertion) { + decryptAssertion(assertion); + } + if (object instanceof LogoutRequest request) { + decryptLogoutRequest(request); + } + } + + /* + * The methods that follow are adapted from OpenSAML's {@link DecryptAssertions}, + * {@link DecryptNameIDs}, and {@link DecryptAttributes}. + * + *

The reason that these OpenSAML classes are not used directly is because they + * reference {@link javax.servlet.http.HttpServletRequest} which is a lower + * Servlet API version than what Spring Security SAML uses. + * + * If OpenSAML 5 updates to {@link jakarta.servlet.http.HttpServletRequest}, then + * this arrangement can be revisited. + */ + + private void decryptResponse(Response response) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + + int count = 0; + int size = response.getEncryptedAssertions().size(); + for (EncryptedAssertion encrypted : response.getEncryptedAssertions()) { + logger.trace(String.format("Decrypting EncryptedAssertion (%d/%d) in Response [%s]", count, size, + response.getID())); + try { + Assertion decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + count++; + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + response.getEncryptedAssertions().removeAll(encrypteds); + response.getAssertions().addAll(decrypteds); + + // Re-marshall the response so that any ID attributes within the decrypted + // Assertions + // will have their ID-ness re-established at the DOM level. + if (!decrypteds.isEmpty()) { + try { + XMLObjectSupport.marshall(response); + } + catch (final MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + } + + private void decryptAssertion(Assertion assertion) { + for (AttributeStatement statement : assertion.getAttributeStatements()) { + decryptAttributes(statement); + } + decryptSubject(assertion.getSubject()); + if (assertion.getConditions() != null) { + for (Condition c : assertion.getConditions().getConditions()) { + if (!(c instanceof DelegationRestrictionType delegation)) { + continue; + } + for (Delegate d : delegation.getDelegates()) { + if (d.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(d.getEncryptedID()); + if (decrypted != null) { + d.setNameID(decrypted); + d.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + } + + private void decryptAttributes(AttributeStatement statement) { + Collection decrypteds = new ArrayList<>(); + Collection encrypteds = new ArrayList<>(); + for (EncryptedAttribute encrypted : statement.getEncryptedAttributes()) { + try { + Attribute decrypted = this.decrypter.decrypt(encrypted); + if (decrypted != null) { + encrypteds.add(encrypted); + decrypteds.add(decrypted); + } + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + statement.getEncryptedAttributes().removeAll(encrypteds); + statement.getAttributes().addAll(decrypteds); + } + + private void decryptSubject(Subject subject) { + if (subject != null) { + if (subject.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(subject.getEncryptedID()); + if (decrypted != null) { + subject.setNameID(decrypted); + subject.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + + for (final SubjectConfirmation sc : subject.getSubjectConfirmations()) { + if (sc.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(sc.getEncryptedID()); + if (decrypted != null) { + sc.setNameID(decrypted); + sc.setEncryptedID(null); + } + } + catch (final DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + } + } + + private void decryptLogoutRequest(LogoutRequest request) { + if (request.getEncryptedID() != null) { + try { + NameID decrypted = (NameID) this.decrypter.decrypt(request.getEncryptedID()); + if (decrypted != null) { + request.setNameID(decrypted); + request.setEncryptedID(null); + } + } + catch (DecryptionException ex) { + throw new Saml2Exception(ex); + } + } + } + + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolver.java index 7e005c79d2..f520dad657 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolver.java @@ -16,29 +16,21 @@ package org.springframework.security.saml2.provider.service.web.authentication.logout; -import java.io.ByteArrayInputStream; -import java.nio.charset.StandardCharsets; - import jakarta.servlet.http.HttpServletRequest; -import net.shibboleth.utilities.java.support.xml.ParserPool; import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.xml.config.XMLObjectProviderRegistry; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.saml.saml2.core.LogoutRequest; import org.opensaml.saml.saml2.core.impl.LogoutRequestUnmarshaller; -import org.w3c.dom.Document; -import org.w3c.dom.Element; import org.springframework.http.HttpMethod; import org.springframework.security.core.Authentication; -import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ParameterNames; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; -import org.springframework.security.saml2.provider.service.authentication.logout.OpenSamlLogoutRequestValidator; import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest; import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; @@ -53,7 +45,12 @@ import org.springframework.util.Assert; /** * An OpenSAML-based implementation of * {@link Saml2LogoutRequestValidatorParametersResolver} + * + * @deprecated Please use a version-specific + * {@link Saml2LogoutRequestValidatorParametersResolver} such as + * {@code OpenSaml4LogoutRequestValidatorParametersResolver} */ +@Deprecated public final class OpenSamlLogoutRequestValidatorParametersResolver implements Saml2LogoutRequestValidatorParametersResolver { @@ -65,19 +62,20 @@ public final class OpenSamlLogoutRequestValidatorParametersResolver new AntPathRequestMatcher("/logout/saml2/slo/{registrationId}"), new AntPathRequestMatcher("/logout/saml2/slo")); + private final OpenSamlOperations saml = new OpenSaml4Template(); + private final RelyingPartyRegistrationRepository registrations; - private final ParserPool parserPool; + private final XMLObjectProviderRegistry registry; private final LogoutRequestUnmarshaller unmarshaller; /** - * Constructs a {@link OpenSamlLogoutRequestValidator} + * Constructs a {@link OpenSamlLogoutRequestValidatorParametersResolver} */ public OpenSamlLogoutRequestValidatorParametersResolver(RelyingPartyRegistrationRepository registrations) { Assert.notNull(registrations, "relyingPartyRegistrationRepository cannot be null"); - XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); - this.parserPool = registry.getParserPool(); + this.registry = ConfigurationService.get(XMLObjectProviderRegistry.class); this.unmarshaller = (LogoutRequestUnmarshaller) XMLObjectProviderRegistrySupport.getUnmarshallerFactory() .getUnmarshaller(LogoutRequest.DEFAULT_ELEMENT_NAME); this.registrations = registrations; @@ -170,8 +168,11 @@ public final class OpenSamlLogoutRequestValidatorParametersResolver private Saml2LogoutRequestValidatorParameters logoutRequestByEntityId(HttpServletRequest request, Authentication authentication) { String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST); - byte[] b = Saml2Utils.samlDecode(serialized); - LogoutRequest logoutRequest = parse(inflateIfRequired(request, b)); + LogoutRequest logoutRequest = this.saml + .deserialize(org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2Utils + .withEncoded(serialized) + .inflate(HttpMethod.GET.matches(request.getMethod())) + .decode()); String issuer = logoutRequest.getIssuer().getValue(); RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(issuer); return logoutRequestByRegistration(request, registration, authentication); @@ -199,25 +200,6 @@ public final class OpenSamlLogoutRequestValidatorParametersResolver return new Saml2LogoutRequestValidatorParameters(logoutRequest, registration, authentication); } - private String inflateIfRequired(HttpServletRequest request, byte[] b) { - if (HttpMethod.GET.matches(request.getMethod())) { - return Saml2Utils.samlInflate(b); - } - return new String(b, StandardCharsets.UTF_8); - } - - private LogoutRequest parse(String request) throws Saml2Exception { - try { - Document document = this.parserPool - .parse(new ByteArrayInputStream(request.getBytes(StandardCharsets.UTF_8))); - Element element = document.getDocumentElement(); - return (LogoutRequest) this.unmarshaller.unmarshall(element); - } - catch (Exception ex) { - throw new Saml2Exception("Failed to deserialize LogoutRequest", ex); - } - } - private RelyingPartyRegistration fromRequest(HttpServletRequest request, RelyingPartyRegistration registration) { RelyingPartyRegistrationPlaceholderResolvers.UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers .uriResolver(request, registration); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlOperations.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlOperations.java new file mode 100644 index 0000000000..07277ec68d --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlOperations.java @@ -0,0 +1,184 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import javax.xml.namespace.QName; + +import org.opensaml.core.xml.XMLObject; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.RequestAbstractType; +import org.opensaml.saml.saml2.core.StatusResponseType; +import org.opensaml.xmlsec.signature.SignableXMLObject; +import org.w3c.dom.Element; + +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.web.util.UriComponentsBuilder; + +interface OpenSamlOperations { + + T build(QName elementName); + + T deserialize(String serialized); + + T deserialize(InputStream serialized); + + SerializationConfigurer serialize(XMLObject object); + + SerializationConfigurer serialize(Element element); + + SignatureConfigurer withSigningKeys(Collection credentials); + + VerificationConfigurer withVerificationKeys(Collection credentials); + + DecryptionConfigurer withDecryptionKeys(Collection credentials); + + interface SerializationConfigurer> { + + B prettyPrint(boolean pretty); + + String serialize(); + + } + + interface SignatureConfigurer> { + + B algorithms(List algs); + + O sign(O object); + + Map sign(Map params); + + } + + interface VerificationConfigurer { + + VerificationConfigurer entityId(String entityId); + + Collection verify(SignableXMLObject signable); + + Collection verify(VerificationConfigurer.RedirectParameters parameters); + + final class RedirectParameters { + + private final String id; + + private final Issuer issuer; + + private final String algorithm; + + private final byte[] signature; + + private final byte[] content; + + RedirectParameters(Map parameters, String parametersQuery, RequestAbstractType request) { + this.id = request.getID(); + this.issuer = request.getIssuer(); + this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG); + if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) { + this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE)); + } + else { + this.signature = null; + } + Map queryParams = UriComponentsBuilder.newInstance() + .query(parametersQuery) + .build(true) + .getQueryParams() + .toSingleValueMap(); + String relayState = parameters.get(Saml2ParameterNames.RELAY_STATE); + this.content = getContent(Saml2ParameterNames.SAML_REQUEST, relayState, queryParams); + } + + RedirectParameters(Map parameters, String parametersQuery, StatusResponseType response) { + this.id = response.getID(); + this.issuer = response.getIssuer(); + this.algorithm = parameters.get(Saml2ParameterNames.SIG_ALG); + if (parameters.get(Saml2ParameterNames.SIGNATURE) != null) { + this.signature = Saml2Utils.samlDecode(parameters.get(Saml2ParameterNames.SIGNATURE)); + } + else { + this.signature = null; + } + Map queryParams = UriComponentsBuilder.newInstance() + .query(parametersQuery) + .build(true) + .getQueryParams() + .toSingleValueMap(); + String relayState = parameters.get(Saml2ParameterNames.RELAY_STATE); + this.content = getContent(Saml2ParameterNames.SAML_RESPONSE, relayState, queryParams); + } + + static byte[] getContent(String samlObject, String relayState, final Map queryParams) { + if (Objects.nonNull(relayState)) { + return String + .format("%s=%s&%s=%s&%s=%s", samlObject, queryParams.get(samlObject), + Saml2ParameterNames.RELAY_STATE, queryParams.get(Saml2ParameterNames.RELAY_STATE), + Saml2ParameterNames.SIG_ALG, queryParams.get(Saml2ParameterNames.SIG_ALG)) + .getBytes(StandardCharsets.UTF_8); + } + else { + return String + .format("%s=%s&%s=%s", samlObject, queryParams.get(samlObject), Saml2ParameterNames.SIG_ALG, + queryParams.get(Saml2ParameterNames.SIG_ALG)) + .getBytes(StandardCharsets.UTF_8); + } + } + + String getId() { + return this.id; + } + + Issuer getIssuer() { + return this.issuer; + } + + byte[] getContent() { + return this.content; + } + + String getAlgorithm() { + return this.algorithm; + } + + byte[] getSignature() { + return this.signature; + } + + boolean hasSignature() { + return this.signature != null; + } + + } + + } + + interface DecryptionConfigurer { + + void decrypt(XMLObject object); + + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlSigningUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlSigningUtils.java deleted file mode 100644 index 0b7ef324dd..0000000000 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlSigningUtils.java +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.web.authentication.logout; - -import java.nio.charset.StandardCharsets; -import java.security.PrivateKey; -import java.security.cert.X509Certificate; -import java.util.ArrayList; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -import net.shibboleth.utilities.java.support.resolver.CriteriaSet; -import net.shibboleth.utilities.java.support.xml.SerializeSupport; -import org.opensaml.core.xml.XMLObject; -import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; -import org.opensaml.core.xml.io.Marshaller; -import org.opensaml.core.xml.io.MarshallingException; -import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; -import org.opensaml.security.SecurityException; -import org.opensaml.security.credential.BasicCredential; -import org.opensaml.security.credential.Credential; -import org.opensaml.security.credential.CredentialSupport; -import org.opensaml.security.credential.UsageType; -import org.opensaml.xmlsec.SignatureSigningParameters; -import org.opensaml.xmlsec.SignatureSigningParametersResolver; -import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; -import org.opensaml.xmlsec.crypto.XMLSigningUtil; -import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; -import org.opensaml.xmlsec.keyinfo.KeyInfoGeneratorManager; -import org.opensaml.xmlsec.keyinfo.NamedKeyInfoGeneratorManager; -import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; -import org.opensaml.xmlsec.signature.SignableXMLObject; -import org.opensaml.xmlsec.signature.support.SignatureConstants; -import org.opensaml.xmlsec.signature.support.SignatureSupport; -import org.w3c.dom.Element; - -import org.springframework.security.saml2.Saml2Exception; -import org.springframework.security.saml2.core.Saml2ParameterNames; -import org.springframework.security.saml2.core.Saml2X509Credential; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.util.Assert; -import org.springframework.web.util.UriComponentsBuilder; -import org.springframework.web.util.UriUtils; - -/** - * Utility methods for signing SAML components with OpenSAML - * - * For internal use only. - * - * @author Josh Cummings - */ -final class OpenSamlSigningUtils { - - static String serialize(XMLObject object) { - try { - Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); - Element element = marshaller.marshall(object); - return SerializeSupport.nodeToString(element); - } - catch (MarshallingException ex) { - throw new Saml2Exception(ex); - } - } - - static O sign(O object, RelyingPartyRegistration relyingPartyRegistration) { - SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration); - try { - SignatureSupport.signObject(object, parameters); - return object; - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - - static QueryParametersPartial sign(RelyingPartyRegistration registration) { - return new QueryParametersPartial(registration); - } - - private static SignatureSigningParameters resolveSigningParameters( - RelyingPartyRegistration relyingPartyRegistration) { - List credentials = resolveSigningCredentials(relyingPartyRegistration); - List algorithms = relyingPartyRegistration.getAssertingPartyMetadata().getSigningAlgorithms(); - List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); - String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; - SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); - CriteriaSet criteria = new CriteriaSet(); - BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); - signingConfiguration.setSigningCredentials(credentials); - signingConfiguration.setSignatureAlgorithms(algorithms); - signingConfiguration.setSignatureReferenceDigestMethods(digests); - signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); - signingConfiguration.setKeyInfoGeneratorManager(buildSignatureKeyInfoGeneratorManager()); - criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration)); - try { - SignatureSigningParameters parameters = resolver.resolveSingle(criteria); - Assert.notNull(parameters, "Failed to resolve any signing credential"); - return parameters; - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - - private static NamedKeyInfoGeneratorManager buildSignatureKeyInfoGeneratorManager() { - final NamedKeyInfoGeneratorManager namedManager = new NamedKeyInfoGeneratorManager(); - - namedManager.setUseDefaultManager(true); - final KeyInfoGeneratorManager defaultManager = namedManager.getDefaultManager(); - - // Generator for X509Credentials - final X509KeyInfoGeneratorFactory x509Factory = new X509KeyInfoGeneratorFactory(); - x509Factory.setEmitEntityCertificate(true); - x509Factory.setEmitEntityCertificateChain(true); - - defaultManager.registerFactory(x509Factory); - - return namedManager; - } - - private static List resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) { - List credentials = new ArrayList<>(); - for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) { - X509Certificate certificate = x509Credential.getCertificate(); - PrivateKey privateKey = x509Credential.getPrivateKey(); - BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); - credential.setEntityId(relyingPartyRegistration.getEntityId()); - credential.setUsageType(UsageType.SIGNING); - credentials.add(credential); - } - return credentials; - } - - private OpenSamlSigningUtils() { - - } - - static class QueryParametersPartial { - - final RelyingPartyRegistration registration; - - final Map components = new LinkedHashMap<>(); - - QueryParametersPartial(RelyingPartyRegistration registration) { - this.registration = registration; - } - - QueryParametersPartial param(String key, String value) { - this.components.put(key, value); - return this; - } - - Map parameters() { - SignatureSigningParameters parameters = resolveSigningParameters(this.registration); - Credential credential = parameters.getSigningCredential(); - String algorithmUri = parameters.getSignatureAlgorithm(); - this.components.put(Saml2ParameterNames.SIG_ALG, algorithmUri); - UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); - for (Map.Entry component : this.components.entrySet()) { - builder.queryParam(component.getKey(), - UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); - } - String queryString = builder.build(true).toString().substring(1); - try { - byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, - queryString.getBytes(StandardCharsets.UTF_8)); - String b64Signature = Saml2Utils.samlEncode(rawSignature); - this.components.put(Saml2ParameterNames.SIGNATURE, b64Signature); - } - catch (SecurityException ex) { - throw new Saml2Exception(ex); - } - return this.components; - } - - } - -} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2Utils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2Utils.java index 95046bc3a1..547fdf959d 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2Utils.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2Utils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.web.authentication.l import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.Base64; import java.util.zip.Deflater; import java.util.zip.DeflaterOutputStream; @@ -73,4 +74,123 @@ final class Saml2Utils { } } + static EncodingConfigurer withDecoded(String decoded) { + return new EncodingConfigurer(decoded); + } + + static DecodingConfigurer withEncoded(String encoded) { + return new DecodingConfigurer(encoded); + } + + static final class EncodingConfigurer { + + private final String decoded; + + private boolean deflate; + + private EncodingConfigurer(String decoded) { + this.decoded = decoded; + } + + EncodingConfigurer deflate(boolean deflate) { + this.deflate = deflate; + return this; + } + + String encode() { + byte[] bytes = (this.deflate) ? Saml2Utils.samlDeflate(this.decoded) + : this.decoded.getBytes(StandardCharsets.UTF_8); + return Saml2Utils.samlEncode(bytes); + } + + } + + static final class DecodingConfigurer { + + private static final Base64Checker BASE_64_CHECKER = new Base64Checker(); + + private final String encoded; + + private boolean inflate; + + private boolean requireBase64; + + private DecodingConfigurer(String encoded) { + this.encoded = encoded; + } + + DecodingConfigurer inflate(boolean inflate) { + this.inflate = inflate; + return this; + } + + DecodingConfigurer requireBase64(boolean requireBase64) { + this.requireBase64 = requireBase64; + return this; + } + + String decode() { + if (this.requireBase64) { + BASE_64_CHECKER.checkAcceptable(this.encoded); + } + byte[] bytes = Saml2Utils.samlDecode(this.encoded); + return (this.inflate) ? Saml2Utils.samlInflate(bytes) : new String(bytes, StandardCharsets.UTF_8); + } + + static class Base64Checker { + + private static final int[] values = genValueMapping(); + + Base64Checker() { + + } + + private static int[] genValueMapping() { + byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + .getBytes(StandardCharsets.ISO_8859_1); + + int[] values = new int[256]; + Arrays.fill(values, -1); + for (int i = 0; i < alphabet.length; i++) { + values[alphabet[i] & 0xff] = i; + } + return values; + } + + boolean isAcceptable(String s) { + int goodChars = 0; + int lastGoodCharVal = -1; + + // count number of characters from Base64 alphabet + for (int i = 0; i < s.length(); i++) { + int val = values[0xff & s.charAt(i)]; + if (val != -1) { + lastGoodCharVal = val; + goodChars++; + } + } + + // in cases of an incomplete final chunk, ensure the unused bits are zero + switch (goodChars % 4) { + case 0: + return true; + case 2: + return (lastGoodCharVal & 0b1111) == 0; + case 3: + return (lastGoodCharVal & 0b11) == 0; + default: + return false; + } + } + + void checkAcceptable(String ins) { + if (!isAcceptable(ins)) { + throw new IllegalArgumentException("Failed to decode SAMLResponse"); + } + } + + } + + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java new file mode 100644 index 0000000000..c7aeb8b878 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java @@ -0,0 +1,153 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import java.nio.charset.StandardCharsets; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensaml.core.xml.XMLObject; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; +import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; +import org.springframework.security.saml2.provider.service.authentication.TestSaml2Authentications; +import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.given; + +@ExtendWith(MockitoExtension.class) +public final class OpenSaml4LogoutRequestValidatorParametersResolverTests { + + @Mock + RelyingPartyRegistrationRepository registrations; + + private final OpenSamlOperations saml = new OpenSaml4Template(); + + private RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); + + private OpenSaml4LogoutRequestValidatorParametersResolver resolver; + + @BeforeEach + void setup() { + this.resolver = new OpenSaml4LogoutRequestValidatorParametersResolver(this.registrations); + } + + @Test + void saml2LogoutRegistrationIdResolveWhenMatchesThenParameters() { + String registrationId = this.registration.getRegistrationId(); + MockHttpServletRequest request = post("/logout/saml2/slo/" + registrationId); + Authentication authentication = new TestingAuthenticationToken("user", "pass"); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + given(this.registrations.findByRegistrationId(registrationId)).willReturn(this.registration); + Saml2LogoutRequestValidatorParameters parameters = this.resolver.resolve(request, authentication); + assertThat(parameters.getAuthentication()).isEqualTo(authentication); + assertThat(parameters.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(registrationId); + assertThat(parameters.getLogoutRequest().getSamlRequest()).isEqualTo("request"); + } + + @Test + void saml2LogoutRegistrationIdWhenUnauthenticatedThenParameters() { + String registrationId = this.registration.getRegistrationId(); + MockHttpServletRequest request = post("/logout/saml2/slo/" + registrationId); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + given(this.registrations.findByRegistrationId(registrationId)).willReturn(this.registration); + Saml2LogoutRequestValidatorParameters parameters = this.resolver.resolve(request, null); + assertThat(parameters.getAuthentication()).isNull(); + assertThat(parameters.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(registrationId); + assertThat(parameters.getLogoutRequest().getSamlRequest()).isEqualTo("request"); + } + + @Test + void saml2LogoutResolveWhenAuthenticatedThenParameters() { + String registrationId = this.registration.getRegistrationId(); + MockHttpServletRequest request = post("/logout/saml2/slo"); + Authentication authentication = TestSaml2Authentications.authentication(); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + given(this.registrations.findByRegistrationId(registrationId)).willReturn(this.registration); + Saml2LogoutRequestValidatorParameters parameters = this.resolver.resolve(request, authentication); + assertThat(parameters.getAuthentication()).isEqualTo(authentication); + assertThat(parameters.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(registrationId); + assertThat(parameters.getLogoutRequest().getSamlRequest()).isEqualTo("request"); + } + + @Test + void saml2LogoutResolveWhenUnauthenticatedThenParameters() { + String registrationId = this.registration.getRegistrationId(); + MockHttpServletRequest request = post("/logout/saml2/slo"); + String logoutRequest = serialize(TestOpenSamlObjects.logoutRequest()); + String encoded = Saml2Utils.samlEncode(logoutRequest.getBytes(StandardCharsets.UTF_8)); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, encoded); + given(this.registrations.findUniqueByAssertingPartyEntityId(TestOpenSamlObjects.ASSERTING_PARTY_ENTITY_ID)) + .willReturn(this.registration); + Saml2LogoutRequestValidatorParameters parameters = this.resolver.resolve(request, null); + assertThat(parameters.getAuthentication()).isNull(); + assertThat(parameters.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(registrationId); + assertThat(parameters.getLogoutRequest().getSamlRequest()).isEqualTo(encoded); + } + + @Test + void saml2LogoutResolveWhenUnauthenticatedGetRequestThenInflates() { + String registrationId = this.registration.getRegistrationId(); + MockHttpServletRequest request = get("/logout/saml2/slo"); + String logoutRequest = serialize(TestOpenSamlObjects.logoutRequest()); + String encoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(logoutRequest)); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, encoded); + given(this.registrations.findUniqueByAssertingPartyEntityId(TestOpenSamlObjects.ASSERTING_PARTY_ENTITY_ID)) + .willReturn(this.registration); + Saml2LogoutRequestValidatorParameters parameters = this.resolver.resolve(request, null); + assertThat(parameters.getAuthentication()).isNull(); + assertThat(parameters.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(registrationId); + assertThat(parameters.getLogoutRequest().getSamlRequest()).isEqualTo(encoded); + } + + @Test + void saml2LogoutRegistrationIdResolveWhenNoMatchingRegistrationIdThenSaml2Exception() { + MockHttpServletRequest request = post("/logout/saml2/slo/id"); + request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.resolver.resolve(request, null)); + } + + private MockHttpServletRequest post(String uri) { + MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); + request.setServletPath(uri); + return request; + } + + private MockHttpServletRequest get(String uri) { + MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); + request.setServletPath(uri); + return request; + } + + private String serialize(XMLObject object) { + return this.saml.serialize(object).serialize(); + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutResponseResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutResponseResolverTests.java index 1eec58c8a8..eba07f55c5 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutResponseResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutResponseResolverTests.java @@ -44,6 +44,8 @@ import static org.mockito.Mockito.verify; */ public class OpenSaml4LogoutResponseResolverTests { + private final OpenSamlOperations saml = new OpenSaml4Template(); + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = mock(RelyingPartyRegistrationResolver.class); @Test @@ -60,7 +62,7 @@ public class OpenSaml4LogoutResponseResolverTests { Authentication authentication = new TestingAuthenticationToken("user", "password"); LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); request.setParameter(Saml2ParameterNames.SAML_REQUEST, - Saml2Utils.samlEncode(OpenSamlSigningUtils.serialize(logoutRequest).getBytes())); + Saml2Utils.samlEncode(this.saml.serialize(logoutRequest).serialize().getBytes())); given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); Saml2LogoutResponse logoutResponse = logoutResponseResolver.resolve(request, authentication); assertThat(logoutResponse).isNotNull(); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolverTests.java deleted file mode 100644 index 27c1ad267c..0000000000 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolverTests.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.web.authentication.logout; - -import java.io.ByteArrayInputStream; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; - -import jakarta.servlet.http.HttpServletRequest; -import org.junit.jupiter.api.Test; -import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; -import org.opensaml.saml.saml2.core.LogoutRequest; -import org.w3c.dom.Document; -import org.w3c.dom.Element; - -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.security.saml2.Saml2Exception; -import org.springframework.security.saml2.core.Saml2ParameterNames; -import org.springframework.security.saml2.provider.service.authentication.DefaultSaml2AuthenticatedPrincipal; -import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; -import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; -import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.BDDMockito.given; -import static org.mockito.Mockito.mock; - -/** - * Tests for {@link OpenSamlLogoutRequestResolver} - * - * @author Josh Cummings - */ -public class OpenSamlLogoutRequestResolverTests { - - RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = mock(RelyingPartyRegistrationResolver.class); - - OpenSamlLogoutRequestResolver logoutRequestResolver = new OpenSamlLogoutRequestResolver( - this.relyingPartyRegistrationResolver); - - @Test - public void resolveRedirectWhenAuthenticatedThenIncludesName() { - RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); - Saml2Authentication authentication = authentication(registration); - HttpServletRequest request = new MockHttpServletRequest(); - given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); - Saml2LogoutRequest saml2LogoutRequest = this.logoutRequestResolver.resolve(request, authentication); - assertThat(saml2LogoutRequest.getParameter(Saml2ParameterNames.SIG_ALG)).isNotNull(); - assertThat(saml2LogoutRequest.getParameter(Saml2ParameterNames.SIGNATURE)).isNotNull(); - assertThat(saml2LogoutRequest.getParameter(Saml2ParameterNames.RELAY_STATE)).isNotNull(); - Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleLogoutServiceBinding(); - LogoutRequest logoutRequest = getLogoutRequest(saml2LogoutRequest.getSamlRequest(), binding); - assertThat(logoutRequest.getNameID().getValue()).isEqualTo(authentication.getName()); - } - - @Test - public void resolvePostWhenAuthenticatedThenIncludesName() { - RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() - .assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)) - .build(); - Saml2Authentication authentication = authentication(registration); - HttpServletRequest request = new MockHttpServletRequest(); - given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); - Saml2LogoutRequest saml2LogoutRequest = this.logoutRequestResolver.resolve(request, authentication); - assertThat(saml2LogoutRequest.getParameter(Saml2ParameterNames.SIG_ALG)).isNull(); - assertThat(saml2LogoutRequest.getParameter(Saml2ParameterNames.SIGNATURE)).isNull(); - assertThat(saml2LogoutRequest.getParameter(Saml2ParameterNames.RELAY_STATE)).isNotNull(); - Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleLogoutServiceBinding(); - LogoutRequest logoutRequest = getLogoutRequest(saml2LogoutRequest.getSamlRequest(), binding); - assertThat(logoutRequest.getNameID().getValue()).isEqualTo(authentication.getName()); - assertThat(logoutRequest.getSessionIndexes()).hasSize(1); - assertThat(logoutRequest.getSessionIndexes().get(0).getValue()).isEqualTo("session-index"); - } - - private Saml2Authentication authentication(RelyingPartyRegistration registration) { - DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("user", new HashMap<>(), - Arrays.asList("session-index")); - principal.setRelyingPartyRegistrationId(registration.getRegistrationId()); - return new Saml2Authentication(principal, "response", new ArrayList<>()); - } - - private LogoutRequest getLogoutRequest(String samlRequest, Saml2MessageBinding binding) { - if (binding == Saml2MessageBinding.REDIRECT) { - samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest)); - } - else { - samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8); - } - try { - Document document = XMLObjectProviderRegistrySupport.getParserPool() - .parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8))); - Element element = document.getDocumentElement(); - return (LogoutRequest) XMLObjectProviderRegistrySupport.getUnmarshallerFactory() - .getUnmarshaller(element) - .unmarshall(element); - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - -} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java index 8e2ae5a393..b57db0a895 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java @@ -18,22 +18,16 @@ package org.springframework.security.saml2.provider.service.web.authentication.l import java.nio.charset.StandardCharsets; -import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensaml.core.xml.XMLObject; -import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; -import org.opensaml.core.xml.io.Marshaller; -import org.opensaml.core.xml.io.MarshallingException; -import org.w3c.dom.Element; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; -import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.Saml2ParameterNames; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; @@ -53,6 +47,8 @@ public final class OpenSamlLogoutRequestValidatorParametersResolverTests { @Mock RelyingPartyRegistrationRepository registrations; + private final OpenSamlOperations saml = new OpenSaml4Template(); + private RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); private OpenSamlLogoutRequestValidatorParametersResolver resolver; @@ -151,14 +147,7 @@ public final class OpenSamlLogoutRequestValidatorParametersResolverTests { } private String serialize(XMLObject object) { - try { - Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); - Element element = marshaller.marshall(object); - return SerializeSupport.nodeToString(element); - } - catch (MarshallingException ex) { - throw new Saml2Exception(ex); - } + return this.saml.serialize(object).serialize(); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolverTests.java deleted file mode 100644 index eb640c470d..0000000000 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolverTests.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.web.authentication.logout; - -import java.io.ByteArrayInputStream; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.HashMap; - -import org.junit.jupiter.api.Test; -import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; -import org.opensaml.saml.saml2.core.LogoutRequest; -import org.opensaml.saml.saml2.core.LogoutResponse; -import org.opensaml.saml.saml2.core.StatusCode; -import org.w3c.dom.Document; -import org.w3c.dom.Element; - -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.security.core.Authentication; -import org.springframework.security.saml2.Saml2Exception; -import org.springframework.security.saml2.core.Saml2ParameterNames; -import org.springframework.security.saml2.provider.service.authentication.DefaultSaml2AuthenticatedPrincipal; -import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; -import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; -import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutResponse; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; -import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.BDDMockito.given; -import static org.mockito.Mockito.mock; - -/** - * Tests for {@link OpenSamlLogoutResponseResolver} - * - * @author Josh Cummings - */ -public class OpenSamlLogoutResponseResolverTests { - - RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = mock(RelyingPartyRegistrationResolver.class); - - OpenSamlLogoutResponseResolver logoutResponseResolver = new OpenSamlLogoutResponseResolver(null, - this.relyingPartyRegistrationResolver); - - @Test - public void resolveRedirectWhenAuthenticatedThenSuccess() { - RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); - MockHttpServletRequest request = new MockHttpServletRequest(); - LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, - Saml2Utils.samlEncode(OpenSamlSigningUtils.serialize(logoutRequest).getBytes())); - request.setParameter(Saml2ParameterNames.RELAY_STATE, "abcd"); - Authentication authentication = authentication(registration); - given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); - Saml2LogoutResponse saml2LogoutResponse = this.logoutResponseResolver.resolve(request, authentication); - assertThat(saml2LogoutResponse.getParameter(Saml2ParameterNames.SIG_ALG)).isNotNull(); - assertThat(saml2LogoutResponse.getParameter(Saml2ParameterNames.SIGNATURE)).isNotNull(); - assertThat(saml2LogoutResponse.getParameter(Saml2ParameterNames.RELAY_STATE)).isSameAs("abcd"); - Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleLogoutServiceBinding(); - LogoutResponse logoutResponse = getLogoutResponse(saml2LogoutResponse.getSamlResponse(), binding); - assertThat(logoutResponse.getStatus().getStatusCode().getValue()).isEqualTo(StatusCode.SUCCESS); - } - - @Test - public void resolvePostWhenAuthenticatedThenSuccess() { - RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() - .assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)) - .build(); - MockHttpServletRequest request = new MockHttpServletRequest(); - LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, - Saml2Utils.samlEncode(OpenSamlSigningUtils.serialize(logoutRequest).getBytes())); - request.setParameter(Saml2ParameterNames.RELAY_STATE, "abcd"); - Authentication authentication = authentication(registration); - given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); - Saml2LogoutResponse saml2LogoutResponse = this.logoutResponseResolver.resolve(request, authentication); - assertThat(saml2LogoutResponse.getParameter(Saml2ParameterNames.SIG_ALG)).isNull(); - assertThat(saml2LogoutResponse.getParameter(Saml2ParameterNames.SIGNATURE)).isNull(); - assertThat(saml2LogoutResponse.getParameter(Saml2ParameterNames.RELAY_STATE)).isSameAs("abcd"); - Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleLogoutServiceBinding(); - LogoutResponse logoutResponse = getLogoutResponse(saml2LogoutResponse.getSamlResponse(), binding); - assertThat(logoutResponse.getStatus().getStatusCode().getValue()).isEqualTo(StatusCode.SUCCESS); - } - - // gh-10923 - @Test - public void resolvePostWithLineBreaksWhenAuthenticatedThenSuccess() { - RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() - .assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)) - .build(); - MockHttpServletRequest request = new MockHttpServletRequest(); - LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequest(registration); - String encoded = new StringBuffer( - Saml2Utils.samlEncode(OpenSamlSigningUtils.serialize(logoutRequest).getBytes())) - .insert(10, "\r\n") - .toString(); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, encoded); - request.setParameter(Saml2ParameterNames.RELAY_STATE, "abcd"); - Authentication authentication = authentication(registration); - given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); - Saml2LogoutResponse saml2LogoutResponse = this.logoutResponseResolver.resolve(request, authentication); - assertThat(saml2LogoutResponse.getParameter(Saml2ParameterNames.SIG_ALG)).isNull(); - assertThat(saml2LogoutResponse.getParameter(Saml2ParameterNames.SIGNATURE)).isNull(); - assertThat(saml2LogoutResponse.getParameter(Saml2ParameterNames.RELAY_STATE)).isSameAs("abcd"); - Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleLogoutServiceBinding(); - LogoutResponse logoutResponse = getLogoutResponse(saml2LogoutResponse.getSamlResponse(), binding); - assertThat(logoutResponse.getStatus().getStatusCode().getValue()).isEqualTo(StatusCode.SUCCESS); - } - - private Saml2Authentication authentication(RelyingPartyRegistration registration) { - DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("user", new HashMap<>()); - principal.setRelyingPartyRegistrationId(registration.getRegistrationId()); - return new Saml2Authentication(principal, "response", new ArrayList<>()); - } - - private LogoutResponse getLogoutResponse(String saml2Response, Saml2MessageBinding binding) { - if (binding == Saml2MessageBinding.REDIRECT) { - saml2Response = Saml2Utils.samlInflate(Saml2Utils.samlDecode(saml2Response)); - } - else { - saml2Response = new String(Saml2Utils.samlDecode(saml2Response), StandardCharsets.UTF_8); - } - try { - Document document = XMLObjectProviderRegistrySupport.getParserPool() - .parse(new ByteArrayInputStream(saml2Response.getBytes(StandardCharsets.UTF_8))); - Element element = document.getDocumentElement(); - return (LogoutResponse) XMLObjectProviderRegistrySupport.getUnmarshallerFactory() - .getUnmarshaller(element) - .unmarshall(element); - } - catch (Exception ex) { - throw new Saml2Exception(ex); - } - } - -} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlSigningUtilsTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlSigningUtilsTests.java deleted file mode 100644 index c6ce3699c5..0000000000 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlSigningUtilsTests.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.web.authentication.logout; - -import java.util.UUID; - -import javax.xml.namespace.QName; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.opensaml.core.xml.XMLObject; -import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; -import org.opensaml.saml.common.SAMLVersion; -import org.opensaml.saml.saml2.core.Issuer; -import org.opensaml.saml.saml2.core.Response; -import org.opensaml.xmlsec.signature.Signature; - -import org.springframework.security.saml2.core.OpenSamlInitializationService; -import org.springframework.security.saml2.core.TestSaml2X509Credentials; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Test open SAML signatures - */ -public class OpenSamlSigningUtilsTests { - - static { - OpenSamlInitializationService.initialize(); - } - - private RelyingPartyRegistration registration; - - @BeforeEach - public void setup() { - this.registration = RelyingPartyRegistration.withRegistrationId("saml-idp") - .entityId("https://some.idp.example.com/entity-id") - .signingX509Credentials((c) -> { - c.add(TestSaml2X509Credentials.relyingPartySigningCredential()); - c.add(TestSaml2X509Credentials.assertingPartySigningCredential()); - }) - .assertingPartyDetails((c) -> c.entityId("https://some.idp.example.com/entity-id") - .singleSignOnServiceLocation("https://some.idp.example.com/service-location")) - .build(); - } - - @Test - public void whenSigningAnObjectThenKeyInfoIsPartOfTheSignature() { - Response response = response("destination", "issuer"); - OpenSamlSigningUtils.sign(response, this.registration); - Signature signature = response.getSignature(); - assertThat(signature).isNotNull(); - assertThat(signature.getKeyInfo()).isNotNull(); - } - - Response response(String destination, String issuerEntityId) { - Response response = build(Response.DEFAULT_ELEMENT_NAME); - response.setID("R" + UUID.randomUUID()); - response.setVersion(SAMLVersion.VERSION_20); - response.setID("_" + UUID.randomUUID()); - response.setDestination(destination); - response.setIssuer(issuer(issuerEntityId)); - return response; - } - - Issuer issuer(String entityId) { - Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME); - issuer.setValue(entityId); - return issuer; - } - - T build(QName qName) { - return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName); - } - -} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutSigningUtilsTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutSigningUtilsTests.java deleted file mode 100644 index 2471efe905..0000000000 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutSigningUtilsTests.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.saml2.provider.service.web.authentication.logout; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.opensaml.saml.saml2.core.LogoutRequest; -import org.opensaml.xmlsec.signature.Signature; - -import org.springframework.security.saml2.core.TestSaml2X509Credentials; -import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Test open SAML Logout signatures - */ -public class Saml2LogoutSigningUtilsTests { - - private RelyingPartyRegistration registration; - - @BeforeEach - public void setup() { - this.registration = RelyingPartyRegistration.withRegistrationId("saml-idp") - .entityId("https://some.idp.example.com/entity-id") - .signingX509Credentials((c) -> { - c.add(TestSaml2X509Credentials.relyingPartySigningCredential()); - c.add(TestSaml2X509Credentials.assertingPartySigningCredential()); - }) - .assertingPartyDetails((c) -> c.entityId("https://some.idp.example.com/entity-id") - .singleSignOnServiceLocation("https://some.idp.example.com/service-location")) - .build(); - } - - @Test - public void whenSigningLogoutRequestRPThenKeyInfoIsPartOfTheSignature() { - LogoutRequest logoutRequest = TestOpenSamlObjects.relyingPartyLogoutRequest(this.registration); - OpenSamlSigningUtils.sign(logoutRequest, this.registration); - Signature signature = logoutRequest.getSignature(); - assertThat(signature).isNotNull(); - assertThat(signature.getKeyInfo()).isNotNull(); - } - -}