@ -16,30 +16,40 @@
@@ -16,30 +16,40 @@
package org.springframework.security.saml2.provider.service.authentication ;
import java.nio.charset.StandardCharsets ;
import java.security.PrivateKey ;
import java.security.cert.X509Certificate ;
import java.time.Clock ;
import java.time.Instant ;
import java.util.Collection ;
import java.util.LinkedHashMap ;
import java.util.Map ;
import java.util.UUID ;
import java.util.function.Consumer ;
import java.util.function.Function ;
import net.shibboleth.utilities.java.support.xml.SerializeSupport ;
import org.joda.time.DateTime ;
import org.opensaml.core.config.ConfigurationService ;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry ;
import org.opensaml.core.xml.io.MarshallingException ;
import org.opensaml.saml.common.xml.SAMLConstants ;
import org.opensaml.saml.saml2.core.AuthnRequest ;
import org.opensaml.saml.saml2.core.Issuer ;
import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder ;
import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller ;
import org.opensaml.saml.saml2.core.impl.IssuerBuilder ;
import org.opensaml.security.SecurityException ;
import org.opensaml.security.credential.BasicCredential ;
import org.opensaml.security.credential.Credential ;
import org.opensaml.security.credential.CredentialSupport ;
import org.opensaml.security.credential.UsageType ;
import org.opensaml.xmlsec.SignatureSigningParameters ;
import org.opensaml.xmlsec.crypto.XMLSigningUtil ;
import org.opensaml.xmlsec.signature.support.SignatureConstants ;
import org.opensaml.xmlsec.signature.support.SignatureException ;
import org.opensaml.xmlsec.signature.support.SignatureSupport ;
import org.w3c.dom.Element ;
import org.springframework.core.convert.converter.Converter ;
import org.springframework.security.saml2.Saml2Exception ;
@ -47,11 +57,14 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService;
@@ -47,11 +57,14 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2X509Credential ;
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder ;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration ;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding ;
import org.springframework.util.Assert ;
import org.springframework.web.util.UriUtils ;
import static java.nio.charset.StandardCharsets.UTF_8 ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode ;
import static org.springframework.util.StringUtils.hasText ;
/ * *
* @since 5 . 2
@ -62,7 +75,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -62,7 +75,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
}
private Clock clock = Clock . systemUTC ( ) ;
private final OpenSamlImplementation saml = OpenSamlImplementation . getInstance ( ) ;
private AuthnRequestMarshaller marshaller ;
private AuthnRequestBuilder authnRequestBuilder ;
private IssuerBuilder issuerBuilder ;
private Converter < Saml2AuthenticationRequestContext , String > protocolBindingResolver =
context - > {
@ -75,6 +91,19 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -75,6 +91,19 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
private Function < Saml2AuthenticationRequestContext , Consumer < AuthnRequest > > authnRequestConsumerResolver
= context - > authnRequest - > { } ;
/ * *
* Creates an { @link OpenSamlAuthenticationRequestFactory }
* /
public OpenSamlAuthenticationRequestFactory ( ) {
XMLObjectProviderRegistry registry = ConfigurationService . get ( XMLObjectProviderRegistry . class ) ;
this . marshaller = ( AuthnRequestMarshaller ) registry . getMarshallerFactory ( )
. getMarshaller ( AuthnRequest . DEFAULT_ELEMENT_NAME ) ;
this . authnRequestBuilder = ( AuthnRequestBuilder ) registry . getBuilderFactory ( )
. getBuilder ( AuthnRequest . DEFAULT_ELEMENT_NAME ) ;
this . issuerBuilder = ( IssuerBuilder ) registry . getBuilderFactory ( )
. getBuilder ( Issuer . DEFAULT_ELEMENT_NAME ) ;
}
@Override
@Deprecated
public String createAuthenticationRequest ( Saml2AuthenticationRequest request ) {
@ -84,8 +113,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -84,8 +113,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
for ( org . springframework . security . saml2 . credentials . Saml2X509Credential credential : request . getCredentials ( ) ) {
if ( credential . isSigningCredential ( ) ) {
Credential cred = getSigningCredential ( credential . getCertificate ( ) , credential . getPrivateKey ( ) , request . getIssuer ( ) ) ;
signAuthnRequest ( authnRequest , cred ) ;
return this . saml . serialize ( authnRequest ) ;
return serialize ( sign ( authnRequest , cred ) ) ;
}
}
throw new IllegalArgumentException ( "No signing credential provided" ) ;
@ -98,8 +126,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -98,8 +126,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
public Saml2PostAuthenticationRequest createPostAuthenticationRequest ( Saml2AuthenticationRequestContext context ) {
AuthnRequest authnRequest = createAuthnRequest ( context ) ;
String xml = context . getRelyingPartyRegistration ( ) . getAssertingPartyDetails ( ) . getWantAuthnRequestsSigned ( ) ?
signThenS erialize ( authnRequest , context . getRelyingPartyRegistration ( ) ) :
this . saml . serialize ( authnRequest ) ;
serialize ( sign ( authnRequest , context . getRelyingPartyRegistration ( ) ) ) :
serialize ( authnRequest ) ;
return Saml2PostAuthenticationRequest . withAuthenticationRequestContext ( context )
. samlRequest ( samlEncode ( xml . getBytes ( UTF_8 ) ) )
@ -112,7 +140,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -112,7 +140,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@Override
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest ( Saml2AuthenticationRequestContext context ) {
AuthnRequest authnRequest = createAuthnRequest ( context ) ;
String xml = this . saml . serialize ( authnRequest ) ;
String xml = serialize ( authnRequest ) ;
Builder result = Saml2RedirectAuthenticationRequest . withAuthenticationRequestContext ( context ) ;
String deflatedAndEncoded = samlEncode ( samlDeflate ( xml ) ) ;
result . samlRequest ( deflatedAndEncoded )
@ -120,15 +148,20 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -120,15 +148,20 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
if ( context . getRelyingPartyRegistration ( ) . getAssertingPartyDetails ( ) . getWantAuthnRequestsSigned ( ) ) {
Collection < Saml2X509Credential > signingCredentials = context . getRelyingPartyRegistration ( ) . getSigningX509Credentials ( ) ;
Map < String , String > signedParams = this . saml . signQueryParameters (
signingCredentials ,
deflatedAndEncoded ,
context . getRelayState ( )
) ;
result . samlRequest ( signedParams . get ( "SAMLRequest" ) )
. relayState ( signedParams . get ( "RelayState" ) )
. sigAlg ( signedParams . get ( "SigAlg" ) )
. signature ( signedParams . get ( "Signature" ) ) ;
for ( Saml2X509Credential credential : signingCredentials ) {
Credential cred = getSigningCredential ( credential . getCertificate ( ) , credential . getPrivateKey ( ) , "" ) ;
Map < String , String > signedParams = signQueryParameters (
cred ,
deflatedAndEncoded ,
context . getRelayState ( ) ) ;
return result
. samlRequest ( signedParams . get ( "SAMLRequest" ) )
. relayState ( signedParams . get ( "RelayState" ) )
. sigAlg ( signedParams . get ( "SigAlg" ) )
. signature ( signedParams . get ( "Signature" ) )
. build ( ) ;
}
throw new Saml2Exception ( "No signing credential provided" ) ;
}
return result . build ( ) ;
@ -144,13 +177,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -144,13 +177,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
private AuthnRequest createAuthnRequest
( String issuer , String destination , String assertionConsumerServiceUrl , String protocolBinding ) {
AuthnRequest auth = this . saml . buildSaml Object ( AuthnRequest . DEFAULT_ELEMENT_NAME ) ;
AuthnRequest auth = this . authnRequestBuilder . buildObject ( ) ;
auth . setID ( "ARQ" + UUID . randomUUID ( ) . toString ( ) . substring ( 1 ) ) ;
auth . setIssueInstant ( new DateTime ( this . clock . millis ( ) ) ) ;
auth . setForceAuthn ( Boolean . FALSE ) ;
auth . setIsPassive ( Boolean . FALSE ) ;
auth . setProtocolBinding ( protocolBinding ) ;
Issuer iss = this . saml . buildSaml Object ( Issuer . DEFAULT_ELEMENT_NAME ) ;
Issuer iss = this . issuerBuilder . buildObject ( ) ;
iss . setValue ( issuer ) ;
auth . setIssuer ( iss ) ;
auth . setDestination ( destination ) ;
@ -192,7 +225,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -192,7 +225,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
* @param protocolBinding either { @link SAMLConstants # SAML2_POST_BINDING_URI } or
* { @link SAMLConstants # SAML2_REDIRECT_BINDING_URI }
* @throws IllegalArgumentException if the protocolBinding is not valid
* @deprecated Use { @link org . springframework . security . saml2 . provider . service . registration . RelyingPartyRegistration . Builder # assertionConsumerServiceBinding }
* @deprecated Use { @link org . springframework . security . saml2 . provider . service . registration . RelyingPartyRegistration . Builder # assertionConsumerServiceBinding ( Saml2MessageBinding ) }
* instead
* /
@Deprecated
@ -205,17 +238,16 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -205,17 +238,16 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
this . protocolBindingResolver = context - > protocolBinding ;
}
private String signThenSerialize ( AuthnRequest authnRequest , RelyingPartyRegistration relyingPartyRegistration ) {
private AuthnRequest sign ( AuthnRequest authnRequest , RelyingPartyRegistration relyingPartyRegistration ) {
for ( Saml2X509Credential credential : relyingPartyRegistration . getSigningX509Credentials ( ) ) {
Credential cred = getSigningCredential (
credential . getCertificate ( ) , credential . getPrivateKey ( ) , relyingPartyRegistration . getEntityId ( ) ) ;
signAuthnRequest ( authnRequest , cred ) ;
return this . saml . serialize ( authnRequest ) ;
return sign ( authnRequest , cred ) ;
}
throw new IllegalArgumentException ( "No signing credential provided" ) ;
}
private void signAuthnRequest ( AuthnRequest authnRequest , Credential credential ) {
private AuthnRequest sign ( AuthnRequest authnRequest , Credential credential ) {
SignatureSigningParameters parameters = new SignatureSigningParameters ( ) ;
parameters . setSigningCredential ( credential ) ;
parameters . setSignatureAlgorithm ( SignatureConstants . ALGO_ID_SIGNATURE_RSA_SHA256 ) ;
@ -223,6 +255,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -223,6 +255,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
parameters . setSignatureCanonicalizationAlgorithm ( SignatureConstants . ALGO_ID_C14N_EXCL_OMIT_COMMENTS ) ;
try {
SignatureSupport . signObject ( authnRequest , parameters ) ;
return authnRequest ;
} catch ( MarshallingException | SignatureException | SecurityException e ) {
throw new Saml2Exception ( e ) ;
}
@ -234,4 +267,59 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@@ -234,4 +267,59 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
cred . setUsageType ( UsageType . SIGNING ) ;
return cred ;
}
private Map < String , String > signQueryParameters (
Credential credential ,
String samlRequest ,
String relayState ) {
Assert . notNull ( samlRequest , "samlRequest cannot be null" ) ;
String algorithmUri = SignatureConstants . ALGO_ID_SIGNATURE_RSA_SHA256 ;
StringBuilder queryString = new StringBuilder ( ) ;
queryString
. append ( "SAMLRequest" )
. append ( "=" )
. append ( UriUtils . encode ( samlRequest , StandardCharsets . ISO_8859_1 ) )
. append ( "&" ) ;
if ( hasText ( relayState ) ) {
queryString
. append ( "RelayState" )
. append ( "=" )
. append ( UriUtils . encode ( relayState , StandardCharsets . ISO_8859_1 ) )
. append ( "&" ) ;
}
queryString
. append ( "SigAlg" )
. append ( "=" )
. append ( UriUtils . encode ( algorithmUri , StandardCharsets . ISO_8859_1 ) ) ;
try {
byte [ ] rawSignature = XMLSigningUtil . signWithURI (
credential ,
algorithmUri ,
queryString . toString ( ) . getBytes ( StandardCharsets . UTF_8 )
) ;
String b64Signature = Saml2Utils . samlEncode ( rawSignature ) ;
Map < String , String > result = new LinkedHashMap < > ( ) ;
result . put ( "SAMLRequest" , samlRequest ) ;
if ( hasText ( relayState ) ) {
result . put ( "RelayState" , relayState ) ;
}
result . put ( "SigAlg" , algorithmUri ) ;
result . put ( "Signature" , b64Signature ) ;
return result ;
}
catch ( SecurityException e ) {
throw new Saml2Exception ( e ) ;
}
}
private String serialize ( AuthnRequest authnRequest ) {
try {
Element element = this . marshaller . marshall ( authnRequest ) ;
return SerializeSupport . nodeToString ( element ) ;
} catch ( MarshallingException e ) {
throw new Saml2Exception ( e ) ;
}
}
}