From 5779121da644e767532480ae9eb3951a4bf55756 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Thu, 23 Jul 2020 16:08:48 -0600 Subject: [PATCH] OpenSamlAuthenticationRequestFactory Uses OpenSAML Directly Closes gh-8774 --- .../OpenSamlAuthenticationRequestFactory.java | 132 +++++++++++++++--- ...SamlAuthenticationRequestFactoryTests.java | 20 ++- 2 files changed, 129 insertions(+), 23 deletions(-) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java index 71d9a18733..f35686bbc2 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java @@ -16,30 +16,40 @@ package org.springframework.security.saml2.provider.service.authentication; +import java.nio.charset.StandardCharsets; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.time.Clock; import java.time.Instant; import java.util.Collection; +import java.util.LinkedHashMap; import java.util.Map; import java.util.UUID; import java.util.function.Consumer; import java.util.function.Function; +import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.joda.time.DateTime; +import org.opensaml.core.config.ConfigurationService; +import org.opensaml.core.xml.config.XMLObjectProviderRegistry; import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder; +import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller; +import org.opensaml.saml.saml2.core.impl.IssuerBuilder; import org.opensaml.security.SecurityException; import org.opensaml.security.credential.BasicCredential; import org.opensaml.security.credential.Credential; import org.opensaml.security.credential.CredentialSupport; import org.opensaml.security.credential.UsageType; import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; import org.opensaml.xmlsec.signature.support.SignatureConstants; import org.opensaml.xmlsec.signature.support.SignatureException; import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.w3c.dom.Element; import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.Saml2Exception; @@ -47,11 +57,14 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.util.Assert; +import org.springframework.web.util.UriUtils; import static java.nio.charset.StandardCharsets.UTF_8; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode; +import static org.springframework.util.StringUtils.hasText; /** * @since 5.2 @@ -62,7 +75,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication } private Clock clock = Clock.systemUTC(); - private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance(); + + private AuthnRequestMarshaller marshaller; + private AuthnRequestBuilder authnRequestBuilder; + private IssuerBuilder issuerBuilder; private Converter protocolBindingResolver = context -> { @@ -75,6 +91,19 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication private Function> authnRequestConsumerResolver = context -> authnRequest -> {}; + /** + * Creates an {@link OpenSamlAuthenticationRequestFactory} + */ + public OpenSamlAuthenticationRequestFactory() { + XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); + this.marshaller = (AuthnRequestMarshaller) registry.getMarshallerFactory() + .getMarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); + this.authnRequestBuilder = (AuthnRequestBuilder) registry.getBuilderFactory() + .getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME); + this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory() + .getBuilder(Issuer.DEFAULT_ELEMENT_NAME); + } + @Override @Deprecated public String createAuthenticationRequest(Saml2AuthenticationRequest request) { @@ -84,8 +113,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) { if (credential.isSigningCredential()) { Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), request.getIssuer()); - signAuthnRequest(authnRequest, cred); - return this.saml.serialize(authnRequest); + return serialize(sign(authnRequest, cred)); } } throw new IllegalArgumentException("No signing credential provided"); @@ -98,8 +126,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) { AuthnRequest authnRequest = createAuthnRequest(context); String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ? - signThenSerialize(authnRequest, context.getRelyingPartyRegistration()) : - this.saml.serialize(authnRequest); + serialize(sign(authnRequest, context.getRelyingPartyRegistration())) : + serialize(authnRequest); return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context) .samlRequest(samlEncode(xml.getBytes(UTF_8))) @@ -112,7 +140,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication @Override public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) { AuthnRequest authnRequest = createAuthnRequest(context); - String xml = this.saml.serialize(authnRequest); + String xml = serialize(authnRequest); Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context); String deflatedAndEncoded = samlEncode(samlDeflate(xml)); result.samlRequest(deflatedAndEncoded) @@ -120,15 +148,20 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) { Collection signingCredentials = context.getRelyingPartyRegistration().getSigningX509Credentials(); - Map signedParams = this.saml.signQueryParameters( - signingCredentials, - deflatedAndEncoded, - context.getRelayState() - ); - result.samlRequest(signedParams.get("SAMLRequest")) - .relayState(signedParams.get("RelayState")) - .sigAlg(signedParams.get("SigAlg")) - .signature(signedParams.get("Signature")); + for (Saml2X509Credential credential : signingCredentials) { + Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), ""); + Map signedParams = signQueryParameters( + cred, + deflatedAndEncoded, + context.getRelayState()); + return result + .samlRequest(signedParams.get("SAMLRequest")) + .relayState(signedParams.get("RelayState")) + .sigAlg(signedParams.get("SigAlg")) + .signature(signedParams.get("Signature")) + .build(); + } + throw new Saml2Exception("No signing credential provided"); } return result.build(); @@ -144,13 +177,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication private AuthnRequest createAuthnRequest (String issuer, String destination, String assertionConsumerServiceUrl, String protocolBinding) { - AuthnRequest auth = this.saml.buildSamlObject(AuthnRequest.DEFAULT_ELEMENT_NAME); + AuthnRequest auth = this.authnRequestBuilder.buildObject(); auth.setID("ARQ" + UUID.randomUUID().toString().substring(1)); auth.setIssueInstant(new DateTime(this.clock.millis())); auth.setForceAuthn(Boolean.FALSE); auth.setIsPassive(Boolean.FALSE); auth.setProtocolBinding(protocolBinding); - Issuer iss = this.saml.buildSamlObject(Issuer.DEFAULT_ELEMENT_NAME); + Issuer iss = this.issuerBuilder.buildObject(); iss.setValue(issuer); auth.setIssuer(iss); auth.setDestination(destination); @@ -192,7 +225,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication * @param protocolBinding either {@link SAMLConstants#SAML2_POST_BINDING_URI} or * {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI} * @throws IllegalArgumentException if the protocolBinding is not valid - * @deprecated Use {@link org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.Builder#assertionConsumerServiceBinding} + * @deprecated Use {@link org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.Builder#assertionConsumerServiceBinding(Saml2MessageBinding)} * instead */ @Deprecated @@ -205,17 +238,16 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication this.protocolBindingResolver = context -> protocolBinding; } - private String signThenSerialize(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) { + private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) { for (Saml2X509Credential credential : relyingPartyRegistration.getSigningX509Credentials()) { Credential cred = getSigningCredential( credential.getCertificate(), credential.getPrivateKey(), relyingPartyRegistration.getEntityId()); - signAuthnRequest(authnRequest, cred); - return this.saml.serialize(authnRequest); + return sign(authnRequest, cred); } throw new IllegalArgumentException("No signing credential provided"); } - private void signAuthnRequest(AuthnRequest authnRequest, Credential credential) { + private AuthnRequest sign(AuthnRequest authnRequest, Credential credential) { SignatureSigningParameters parameters = new SignatureSigningParameters(); parameters.setSigningCredential(credential); parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); @@ -223,6 +255,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS); try { SignatureSupport.signObject(authnRequest, parameters); + return authnRequest; } catch (MarshallingException | SignatureException | SecurityException e) { throw new Saml2Exception(e); } @@ -234,4 +267,59 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication cred.setUsageType(UsageType.SIGNING); return cred; } + + private Map signQueryParameters( + Credential credential, + String samlRequest, + String relayState) { + Assert.notNull(samlRequest, "samlRequest cannot be null"); + String algorithmUri = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256; + StringBuilder queryString = new StringBuilder(); + queryString + .append("SAMLRequest") + .append("=") + .append(UriUtils.encode(samlRequest, StandardCharsets.ISO_8859_1)) + .append("&"); + if (hasText(relayState)) { + queryString + .append("RelayState") + .append("=") + .append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1)) + .append("&"); + } + queryString + .append("SigAlg") + .append("=") + .append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1)); + + try { + byte[] rawSignature = XMLSigningUtil.signWithURI( + credential, + algorithmUri, + queryString.toString().getBytes(StandardCharsets.UTF_8) + ); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + + Map result = new LinkedHashMap<>(); + result.put("SAMLRequest", samlRequest); + if (hasText(relayState)) { + result.put("RelayState", relayState); + } + result.put("SigAlg", algorithmUri); + result.put("Signature", b64Signature); + return result; + } + catch (SecurityException e) { + throw new Saml2Exception(e); + } + } + + private String serialize(AuthnRequest authnRequest) { + try { + Element element = this.marshaller.marshall(authnRequest); + return SerializeSupport.nodeToString(element); + } catch (MarshallingException e) { + throw new Saml2Exception(e); + } + } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java index a273f8bf9f..8b8da0de9e 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java @@ -16,6 +16,7 @@ package org.springframework.security.saml2.provider.service.authentication; +import java.io.ByteArrayInputStream; import java.util.function.Consumer; import java.util.function.Function; @@ -26,7 +27,11 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.core.AuthnRequest; +import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller; +import org.w3c.dom.Document; +import org.w3c.dom.Element; +import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; @@ -37,6 +42,8 @@ import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getParserPool; +import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getUnmarshallerFactory; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate; @@ -56,6 +63,9 @@ public class OpenSamlAuthenticationRequestFactoryTests { private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder; private RelyingPartyRegistration relyingPartyRegistration; + private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory() + .getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); + @Rule public ExpectedException exception = ExpectedException.none(); @@ -224,6 +234,14 @@ public class OpenSamlAuthenticationRequestFactoryTests { else { samlRequest = new String(samlDecode(samlRequest), UTF_8); } - return (AuthnRequest) OpenSamlImplementation.getInstance().resolve(samlRequest); + try { + Document document = getParserPool().parse( + new ByteArrayInputStream(samlRequest.getBytes(UTF_8))); + Element element = document.getDocumentElement(); + return (AuthnRequest) this.unmarshaller.unmarshall(element); + } + catch (Exception e) { + throw new Saml2Exception(e); + } } }