@ -15,12 +15,28 @@
@@ -15,12 +15,28 @@
* /
package org.springframework.security.saml2.provider.service.authentication ;
import java.security.cert.X509Certificate ;
import java.time.Duration ;
import java.util.ArrayList ;
import java.util.Collection ;
import java.util.Collections ;
import java.util.HashMap ;
import java.util.HashSet ;
import java.util.LinkedList ;
import java.util.List ;
import java.util.Map ;
import java.util.Set ;
import javax.annotation.Nonnull ;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet ;
import org.apache.commons.logging.Log ;
import org.apache.commons.logging.LogFactory ;
import org.opensaml.saml.common.SignableSAMLObject ;
import org.opensaml.saml.common.assertion.AssertionValidationException ;
import org.opensaml.core.criterion.EntityIdCriterion ;
import org.opensaml.saml.common.assertion.ValidationContext ;
import org.opensaml.saml.common.assertion.ValidationResult ;
import org.opensaml.saml.common.xml.SAMLConstants ;
import org.opensaml.saml.criterion.ProtocolCriterion ;
import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion ;
import org.opensaml.saml.saml2.assertion.ConditionValidator ;
import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator ;
import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters ;
@ -40,16 +56,20 @@ import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
@@ -40,16 +56,20 @@ import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
import org.opensaml.security.credential.Credential ;
import org.opensaml.security.credential.CredentialResolver ;
import org.opensaml.security.credential.CredentialSupport ;
import org.opensaml.security.credential.UsageType ;
import org.opensaml.security.credential.criteria.impl.EvaluableEntityIDCredentialCriterion ;
import org.opensaml.security.credential.criteria.impl.EvaluableUsageCredentialCriterion ;
import org.opensaml.security.credential.impl.CollectionCredentialResolver ;
import org.opensaml.security.criteria.UsageCriterion ;
import org.opensaml.security.x509.BasicX509Credential ;
import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap ;
import org.opensaml.xmlsec.encryption.support.DecryptionException ;
import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver ;
import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver ;
import org.opensaml.xmlsec.signature.support.SignatureException ;
import org.opensaml.xmlsec.signature.support.SignaturePrevalidator ;
import org.opensaml.xmlsec.signature.support.SignatureTrustEngine ;
import org.opensaml.xmlsec.signature.support.SignatureValidator ;
import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine ;
import org.springframework.core.convert.converter.Converter ;
import org.springframework.security.authentication.AuthenticationProvider ;
import org.springframework.security.core.Authentication ;
@ -60,30 +80,24 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
@@ -60,30 +80,24 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
import org.springframework.security.saml2.Saml2Exception ;
import org.springframework.security.saml2.credentials.Saml2X509Credential ;
import org.springframework.util.Assert ;
import org.springframework.util.StringUtils ;
import java.security.cert.X509Certificate ;
import java.time.Duration ;
import java.util.Collection ;
import java.util.Collections ;
import java.util.HashMap ;
import java.util.HashSet ;
import java.util.LinkedList ;
import java.util.List ;
import java.util.Map ;
import java.util.Set ;
import static java.lang.String.format ;
import static java.util.Collections.singleton ;
import static java.util.Collections.singletonList ;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.CLOCK_SKEW ;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.COND_VALID_AUDIENCES ;
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.DECRYPTION_ERROR ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INVALID_ASSERTION ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INVALID_DESTINATION ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INVALID_ISSUER ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INVALID_SIGNATURE ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.SUBJECT_NOT_FOUND ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.USERNAME_NOT_FOUND ;
import static org.springframework.util.Assert.notNull ;
import static org.springframework.util.StringUtils.hasText ;
/ * *
* Implementation of { @link AuthenticationProvider } for SAML authentications when receiving a
@ -125,6 +139,20 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -125,6 +139,20 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private static Log logger = LogFactory . getLog ( OpenSamlAuthenticationProvider . class ) ;
private final List < ConditionValidator > conditions = Collections . singletonList ( new AudienceRestrictionConditionValidator ( ) ) ;
private final SubjectConfirmationValidator subjectConfirmationValidator = new BearerSubjectConfirmationValidator ( ) {
@Nonnull
@Override
protected ValidationResult validateAddress ( @Nonnull SubjectConfirmation confirmation ,
@Nonnull Assertion assertion , @Nonnull ValidationContext context ) {
// skipping address validation - gh-7514
return ValidationResult . VALID ;
}
} ;
private final List < SubjectConfirmationValidator > subjects = Collections . singletonList ( this . subjectConfirmationValidator ) ;
private final List < StatementValidator > statements = Collections . emptyList ( ) ;
private final SignaturePrevalidator signaturePrevalidator = new SAMLSignatureProfileValidator ( ) ;
private final OpenSamlImplementation saml = OpenSamlImplementation . getInstance ( ) ;
private Converter < Assertion , Collection < ? extends GrantedAuthority > > authoritiesExtractor =
( a - > singletonList ( new SimpleGrantedAuthority ( "ROLE_USER" ) ) ) ;
@ -173,17 +201,17 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -173,17 +201,17 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
public Authentication authenticate ( Authentication authentication ) throws AuthenticationException {
try {
Saml2AuthenticationToken token = ( Saml2AuthenticationToken ) authentication ;
Response samlResponse = getSaml2Response ( token ) ;
Assertion assertion = validateSaml2Response ( token , token . getRecipientUri ( ) , samlResponse ) ;
Response response = parse ( token . getSaml2Response ( ) ) ;
List < Assertion > validAssertions = validateResponse ( token , response ) ;
Assertion assertion = validAssertions . get ( 0 ) ;
String username = getUsername ( token , assertion ) ;
return new Saml2Authentication (
new SimpleSaml2AuthenticatedPrincipal ( username ) , token . getSaml2Response ( ) ,
this . authoritiesMapper . mapAuthorities ( getAssertionAuthorities ( assertion ) )
) ;
this . authoritiesMapper . mapAuthorities ( getAssertionAuthorities ( assertion ) ) ) ;
} catch ( Saml2AuthenticationException e ) {
throw e ;
} catch ( Exception e ) {
throw authException ( Saml2ErrorCodes . INTERNAL_VALIDATION_ERROR , e . getMessage ( ) , e ) ;
throw authException ( INTERNAL_VALIDATION_ERROR , e . getMessage ( ) , e ) ;
}
}
@ -199,241 +227,187 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -199,241 +227,187 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return this . authoritiesExtractor . convert ( assertion ) ;
}
private String getUsername ( Saml2AuthenticationToken token , Assertion assertion ) throws Saml2AuthenticationException {
String username = null ;
Subject subject = assertion . getSubject ( ) ;
if ( subject = = null ) {
throw authException ( SUBJECT_NOT_FOUND , "Assertion [" + assertion . getID ( ) + "] is missing a subject" ) ;
}
if ( subject . getNameID ( ) ! = null ) {
username = subject . getNameID ( ) . getValue ( ) ;
}
else if ( subject . getEncryptedID ( ) ! = null ) {
NameID nameId = decrypt ( token , subject . getEncryptedID ( ) ) ;
username = nameId . getValue ( ) ;
}
if ( username = = null ) {
throw authException ( USERNAME_NOT_FOUND , "Assertion [" + assertion . getID ( ) + "] is missing a user identifier" ) ;
private Response parse ( String response ) throws Saml2Exception , Saml2AuthenticationException {
try {
Object result = this . saml . resolve ( response ) ;
if ( result instanceof Response ) {
return ( Response ) result ;
}
else {
throw authException ( UNKNOWN_RESPONSE_CLASS , "Invalid response class:" + result . getClass ( ) . getName ( ) ) ;
}
} catch ( Saml2Exception x ) {
throw authException ( MALFORMED_RESPONSE_DATA , x . getMessage ( ) , x ) ;
}
return username ;
}
private Assertion validateSaml2Response ( Saml2AuthenticationToken token ,
String recipient ,
Response samlResponse ) throws Saml2AuthenticationException {
//optional validation if the response contains a destination
if ( hasText ( samlResponse . getDestination ( ) ) & & ! recipient . equals ( samlResponse . getDestination ( ) ) ) {
throw authException ( INVALID_DESTINATION , "Invalid SAML response destination: " + samlResponse . getDestination ( ) ) ;
}
private List < Assertion > validateResponse ( Saml2AuthenticationToken token , Response response )
throws Saml2AuthenticationException {
String issuer = samlResponse . getIssuer ( ) . getValue ( ) ;
List < Assertion > validAssertions = new ArrayList < > ( ) ;
String issuer = response . getIssuer ( ) . getValue ( ) ;
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Validating SAML response from " + issuer ) ;
}
if ( ! hasText ( issuer ) | | ( ! issuer . equals ( token . getIdpEntityId ( ) ) ) ) {
String message = String . format ( "Response issuer '%s' doesn't match '%s'" , issuer , token . getIdpEntityId ( ) ) ;
throw authException ( INVALID_ISSUER , message ) ;
}
Saml2AuthenticationException lastValidationError = null ;
boolean responseSigned = hasValidSignature ( samlResponse , token ) ;
for ( Assertion a : samlResponse . getAssertions ( ) ) {
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Checking plain assertion validity " + a ) ;
}
try {
validateAssertion ( recipient , a , token , ! responseSigned ) ;
return a ;
} catch ( Saml2AuthenticationException e ) {
lastValidationError = e ;
}
List < Assertion > assertions = new ArrayList < > ( response . getAssertions ( ) ) ;
for ( EncryptedAssertion encryptedAssertion : response . getEncryptedAssertions ( ) ) {
Assertion assertion = decrypt ( token , encryptedAssertion ) ;
assertions . add ( assertion ) ;
}
for ( EncryptedAssertion ea : samlResponse . getEncryptedAssertions ( ) ) {
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Checking encrypted assertion validity " + ea ) ;
}
try {
Assertion a = decrypt ( token , ea ) ;
validateAssertion ( recipient , a , token , ! responseSigned ) ;
return a ;
} catch ( Saml2AuthenticationException e ) {
lastValidationError = e ;
}
}
if ( lastValidationError ! = null ) {
throw lastValidationError ;
}
else {
if ( assertions . isEmpty ( ) ) {
throw authException ( MALFORMED_RESPONSE_DATA , "No assertions found in response." ) ;
}
}
private boolean hasValidSignature ( SignableSAMLObject samlObject , Saml2AuthenticationToken token ) {
if ( ! samlObject . isSigned ( ) ) {
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "SAML object is not signed, no signatures found" ) ;
}
return false ;
if ( ! isSigned ( response , assertions ) ) {
throw authException ( INVALID_SIGNATURE , "Either the response or one of the assertions is unsigned. " +
"Please either sign the response or all of the assertions." ) ;
}
List < X509Certificate > verificationKeys = getVerificationCertificates ( token ) ;
if ( verificationKeys . isEmpty ( ) ) {
return false ;
}
SignatureTrustEngine signatureTrustEngine = buildSignatureTrustEngine ( token ) ;
for ( X509Certificate certificate : verificationKeys ) {
Credential credential = getVerificationCredential ( certificate ) ;
Map < String , Saml2AuthenticationException > validationExceptions = new HashMap < > ( ) ;
if ( response . isSigned ( ) ) {
SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator ( ) ;
try {
SignatureValidator . validate ( samlObject . getSignature ( ) , credential ) ;
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Valid signature found in SAML object:" + samlObject . getClass ( ) . getName ( ) ) ;
}
return true ;
profileValidator . validate ( response . getSignature ( ) ) ;
} catch ( Exception e ) {
validationExceptions . put ( INVALID_SIGNATURE , authException ( INVALID_SIGNATURE ,
"Invalid signature for SAML Response [" + response . getID ( ) + "]" , e ) ) ;
}
catch ( SignatureException ignored ) {
if ( logger . isTraceEnabled ( ) ) {
logger . trace ( "Signature validation failed with cert:" + certificate . toString ( ) , ignored ) ;
}
else if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Signature validation failed with cert:" + certificate . toString ( ) ) ;
try {
CriteriaSet criteriaSet = new CriteriaSet ( ) ;
criteriaSet . add ( new EvaluableEntityIDCredentialCriterion ( new EntityIdCriterion ( issuer ) ) ) ;
criteriaSet . add ( new EvaluableProtocolRoleDescriptorCriterion ( new ProtocolCriterion ( SAMLConstants . SAML20P_NS ) ) ) ;
criteriaSet . add ( new EvaluableUsageCredentialCriterion ( new UsageCriterion ( UsageType . SIGNING ) ) ) ;
if ( ! signatureTrustEngine . validate ( response . getSignature ( ) , criteriaSet ) ) {
validationExceptions . put ( INVALID_SIGNATURE , authException ( INVALID_SIGNATURE ,
"Invalid signature for SAML Response [" + response . getID ( ) + "]" ) ) ;
}
} catch ( Exception e ) {
validationExceptions . put ( INVALID_SIGNATURE , authException ( INVALID_SIGNATURE ,
"Invalid signature for SAML Response [" + response . getID ( ) + "]" , e ) ) ;
}
}
return false ;
}
private void validateAssertion ( String recipient , Assertion a , Saml2AuthenticationToken token , boolean signatureRequired ) {
SAML20AssertionValidator validator = getAssertionValidator ( token ) ;
Map < String , Object > validationParams = new HashMap < > ( ) ;
validationParams . put ( SAML2AssertionValidationParameters . SIGNATURE_REQUIRED , false ) ;
validationParams . put (
SAML2AssertionValidationParameters . CLOCK_SKEW ,
this . responseTimeValidationSkew . toMillis ( )
) ;
validationParams . put (
SAML2AssertionValidationParameters . COND_VALID_AUDIENCES ,
singleton ( token . getLocalSpEntityId ( ) )
) ;
if ( hasText ( recipient ) ) {
validationParams . put ( SAML2AssertionValidationParameters . SC_VALID_RECIPIENTS , singleton ( recipient ) ) ;
String destination = response . getDestination ( ) ;
if ( StringUtils . hasText ( destination ) & & ! destination . equals ( token . getRecipientUri ( ) ) ) {
String message = "Invalid destination [" + destination + "] for SAML response [" + response . getID ( ) + "]" ;
validationExceptions . put ( INVALID_DESTINATION , authException ( INVALID_DESTINATION , message ) ) ;
}
if ( signatureRequired & & ! hasValidSignature ( a , token ) ) {
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( format ( "Assertion [%s] does not a valid signature." , a . getID ( ) ) ) ;
}
throw authException ( Saml2ErrorCodes . INVALID_SIGNATURE , "Assertion doesn't have a valid signature." ) ;
}
//ensure that OpenSAML doesn't attempt signature validation, already performed
a . setSignature ( null ) ;
//ensure that we don't validate IP addresses as part of our validation gh-7514
if ( a . getSubject ( ) ! = null ) {
for ( SubjectConfirmation sc : a . getSubject ( ) . getSubjectConfirmations ( ) ) {
if ( sc . getSubjectConfirmationData ( ) ! = null ) {
sc . getSubjectConfirmationData ( ) . setAddress ( null ) ;
}
}
if ( ! StringUtils . hasText ( issuer ) | | ! issuer . equals ( token . getIdpEntityId ( ) ) ) {
String message = String . format ( "Invalid issuer [%s] for SAML response [%s]" , issuer , response . getID ( ) ) ;
validationExceptions . put ( INVALID_ISSUER , authException ( INVALID_ISSUER , message ) ) ;
}
//remainder of assertion validation
ValidationContext vctx = new ValidationContext ( validationParams ) ;
try {
ValidationResult result = validator . validate ( a , vctx ) ;
boolean valid = result . equals ( ValidationResult . VALID ) ;
if ( ! valid ) {
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( format ( "Failed to validate assertion from %s" , token . getIdpEntityId ( ) ) ) ;
}
throw authException ( Saml2ErrorCodes . INVALID_ASSERTION , vctx . getValidationFailureMessage ( ) ) ;
}
SAML20AssertionValidator validator = buildSamlAssertionValidator ( signatureTrustEngine ) ;
ValidationContext context = buildValidationContext ( token , response ) ;
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Validating " + assertions . size ( ) + " assertions" ) ;
}
catch ( AssertionValidationException e ) {
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Failed to validate assertion:" , e ) ;
for ( Assertion assertion : assertions ) {
if ( logger . isTraceEnabled ( ) ) {
logger . trace ( "Validating assertion " + assertion . getID ( ) ) ;
}
try {
validAssertions . add ( validateAssertion ( assertion , validator , context ) ) ;
} catch ( Exception e ) {
String message = String . format ( "Invalid assertion [%s] for SAML response [%s]" , assertion . getID ( ) , response . getID ( ) ) ;
validationExceptions . put ( INVALID_ASSERTION , authException ( INVALID_ASSERTION , message , e ) ) ;
}
throw authException ( Saml2ErrorCodes . INTERNAL_VALIDATION_ERROR , e . getMessage ( ) , e ) ;
}
}
private Response getSaml2Response ( Saml2AuthenticationToken token ) throws Saml2Exception , Saml2AuthenticationException {
try {
Object result = this . saml . resolve ( token . getSaml2Response ( ) ) ;
if ( result instanceof Response ) {
return ( Response ) result ;
if ( validationExceptions . isEmpty ( ) ) {
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Successfully validated SAML Response [" + response . getID ( ) + "]" ) ;
}
else {
throw authException ( UNKNOWN_RESPONSE_CLASS , "Invalid response class:" + result . getClass ( ) . getName ( ) ) ;
} else {
if ( logger . isTraceEnabled ( ) ) {
logger . debug ( "Found " + validationExceptions . size ( ) + " validation errors in SAML response [" + response . getID ( ) + "]: " +
validationExceptions . values ( ) ) ;
} else if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Found " + validationExceptions . size ( ) + " validation errors in SAML response [" + response . getID ( ) + "]" ) ;
}
} catch ( Saml2Exception x ) {
throw authException ( MALFORMED_RESPONSE_DATA , x . getMessage ( ) , x ) ;
}
}
if ( ! validationExceptions . isEmpty ( ) ) {
throw validationExceptions . values ( ) . iterator ( ) . next ( ) ;
}
if ( validAssertions . isEmpty ( ) ) {
throw authException ( MALFORMED_RESPONSE_DATA , "No valid assertions found in response." ) ;
}
private Saml2Error validationError ( String code , String description ) {
return new Saml2Error (
code ,
description
) ;
return validAssertions ;
}
private Saml2AuthenticationException authException ( String code , String description ) throws Saml2AuthenticationException {
return new Saml2AuthenticationException (
validationError ( code , description )
) ;
}
private boolean isSigned ( Response samlResponse , List < Assertion > assertions ) {
if ( samlResponse . isSigned ( ) ) {
return true ;
}
for ( Assertion assertion : assertions ) {
if ( ! assertion . isSigned ( ) ) {
return false ;
}
}
private Saml2AuthenticationException authException ( String code , String description , Exception cause ) throws Saml2AuthenticationException {
return new Saml2AuthenticationException (
validationError ( code , description ) ,
cause
) ;
return true ;
}
private SAML20AssertionValidator getAssertionValidator ( Saml2AuthenticationToken provider ) {
List < ConditionValidator > conditions = Collections . singletonList ( new AudienceRestrictionConditionValidator ( ) ) ;
BearerSubjectConfirmationValidator subjectConfirmationValidator = new BearerSubjectConfirmationValidator ( ) ;
List < SubjectConfirmationValidator > subjects = Collections . singletonList ( subjectConfirmationValidator ) ;
List < StatementValidator > statements = Collections . emptyList ( ) ;
private SignatureTrustEngine buildSignatureTrustEngine ( Saml2AuthenticationToken token ) {
Set < Credential > credentials = new HashSet < > ( ) ;
for ( X509Certificate key : getVerificationCertificates ( provider ) ) {
Credential cred = getVerificationCredential ( key ) ;
for ( X509Certificate key : getVerificationCertificates ( token ) ) {
BasicX509Credential cred = new BasicX509Credential ( key ) ;
cred . setUsageType ( UsageType . SIGNING ) ;
cred . setEntityId ( token . getIdpEntityId ( ) ) ;
credentials . add ( cred ) ;
}
CredentialResolver credentialsResolver = new CollectionCredentialResolver ( credentials ) ;
SignatureTrustEngine signatureTrustEngine = new ExplicitKeySignatureTrustEngine (
return new ExplicitKeySignatureTrustEngine (
credentialsResolver ,
DefaultSecurityConfigurationBootstrap . buildBasicInlineKeyInfoCredentialResolver ( )
) ;
SignaturePrevalidator signaturePrevalidator = new SAMLSignatureProfileValidator ( ) ;
return new SAML20AssertionValidator (
conditions ,
subjects ,
statements ,
signatureTrustEngine ,
signaturePrevalidator
) ;
}
private Credential getVerificationCredential ( X509Certificate certificate ) {
return CredentialSupport . getSimpleCredential ( certificate , null ) ;
private ValidationContext buildValidationContext ( Saml2AuthenticationToken token , Response response ) {
Map < String , Object > validationParams = new HashMap < > ( ) ;
validationParams . put ( SIGNATURE_REQUIRED , ! response . isSigned ( ) ) ;
validationParams . put ( CLOCK_SKEW , this . responseTimeValidationSkew . toMillis ( ) ) ;
validationParams . put ( COND_VALID_AUDIENCES , singleton ( token . getLocalSpEntityId ( ) ) ) ;
if ( StringUtils . hasText ( token . getRecipientUri ( ) ) ) {
validationParams . put ( SAML2AssertionValidationParameters . SC_VALID_RECIPIENTS , singleton ( token . getRecipientUri ( ) ) ) ;
}
return new ValidationContext ( validationParams ) ;
}
private Decrypter getDecrypter ( Saml2X509Credential key ) {
Credential credential = CredentialSupport . getSimpleCredential ( key . getCertificate ( ) , key . getPrivateKey ( ) ) ;
KeyInfoCredentialResolver resolver = new StaticKeyInfoCredentialResolver ( credential ) ;
Decrypter decrypter = new Decrypter ( null , resolver , this . saml . getEncryptedKeyResolver ( ) ) ;
decrypter . setRootInNewDocument ( true ) ;
return decrypter ;
private SAML20AssertionValidator buildSamlAssertionValidator ( SignatureTrustEngine signatureTrustEngine ) {
return new SAML20AssertionValidator (
this . conditions , this . subjects , this . statements , signatureTrustEngine , this . signaturePrevalidator ) ;
}
private Assertion validateAssertion ( Assertion assertion ,
SAML20AssertionValidator validator , ValidationContext context ) {
ValidationResult result ;
try {
result = validator . validate ( assertion , context ) ;
} catch ( Exception e ) {
throw new Saml2Exception ( "An error occurred while validation the assertion" , e ) ;
}
if ( result ! = ValidationResult . VALID ) {
throw new Saml2Exception ( "An error occurred while validating the assertion: " +
context . getValidationFailureMessage ( ) ) ;
}
return assertion ;
}
private Assertion decrypt ( Saml2AuthenticationToken token , EncryptedAssertion assertion )
throws Saml2AuthenticationException {
Saml2AuthenticationException last = null ;
List < Saml2X509Credential > decryptionCredentials = getDecryptionCredentials ( token ) ;
if ( decryptionCredentials . isEmpty ( ) ) {
@ -451,22 +425,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -451,22 +425,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
throw last ;
}
private NameID decrypt ( Saml2AuthenticationToken token , EncryptedID assertion ) throws Saml2AuthenticationException {
Saml2AuthenticationException last = null ;
List < Saml2X509Credential > decryptionCredentials = getDecryptionCredentials ( token ) ;
if ( decryptionCredentials . isEmpty ( ) ) {
throw authException ( DECRYPTION_ERROR , "No valid decryption credentials found." ) ;
}
for ( Saml2X509Credential key : decryptionCredentials ) {
Decrypter decrypter = getDecrypter ( key ) ;
try {
return ( NameID ) decrypter . decrypt ( assertion ) ;
}
catch ( DecryptionException e ) {
last = authException ( DECRYPTION_ERROR , e . getMessage ( ) , e ) ;
}
}
throw last ;
private Decrypter getDecrypter ( Saml2X509Credential key ) {
Credential credential = CredentialSupport . getSimpleCredential ( key . getCertificate ( ) , key . getPrivateKey ( ) ) ;
KeyInfoCredentialResolver resolver = new StaticKeyInfoCredentialResolver ( credential ) ;
Decrypter decrypter = new Decrypter ( null , resolver , this . saml . getEncryptedKeyResolver ( ) ) ;
decrypter . setRootInNewDocument ( true ) ;
return decrypter ;
}
private List < Saml2X509Credential > getDecryptionCredentials ( Saml2AuthenticationToken token ) {
@ -488,4 +452,61 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -488,4 +452,61 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
}
return result ;
}
private String getUsername ( Saml2AuthenticationToken token , Assertion assertion )
throws Saml2AuthenticationException {
String username = null ;
Subject subject = assertion . getSubject ( ) ;
if ( subject = = null ) {
throw authException ( SUBJECT_NOT_FOUND , "Assertion [" + assertion . getID ( ) + "] is missing a subject" ) ;
}
if ( subject . getNameID ( ) ! = null ) {
username = subject . getNameID ( ) . getValue ( ) ;
}
else if ( subject . getEncryptedID ( ) ! = null ) {
NameID nameId = decrypt ( token , subject . getEncryptedID ( ) ) ;
username = nameId . getValue ( ) ;
}
if ( username = = null ) {
throw authException ( USERNAME_NOT_FOUND , "Assertion [" + assertion . getID ( ) + "] is missing a user identifier" ) ;
}
return username ;
}
private NameID decrypt ( Saml2AuthenticationToken token , EncryptedID assertion )
throws Saml2AuthenticationException {
Saml2AuthenticationException last = null ;
List < Saml2X509Credential > decryptionCredentials = getDecryptionCredentials ( token ) ;
if ( decryptionCredentials . isEmpty ( ) ) {
throw authException ( DECRYPTION_ERROR , "No valid decryption credentials found." ) ;
}
for ( Saml2X509Credential key : decryptionCredentials ) {
Decrypter decrypter = getDecrypter ( key ) ;
try {
return ( NameID ) decrypter . decrypt ( assertion ) ;
}
catch ( DecryptionException e ) {
last = authException ( DECRYPTION_ERROR , e . getMessage ( ) , e ) ;
}
}
throw last ;
}
private Saml2Error validationError ( String code , String description ) {
return new Saml2Error ( code , description ) ;
}
private Saml2AuthenticationException authException ( String code , String description )
throws Saml2AuthenticationException {
return new Saml2AuthenticationException ( validationError ( code , description ) ) ;
}
private Saml2AuthenticationException authException ( String code , String description , Exception cause )
throws Saml2AuthenticationException {
return new Saml2AuthenticationException ( validationError ( code , description ) , cause ) ;
}
}