@ -157,6 +157,7 @@ import org.springframework.util.StringUtils;
@@ -157,6 +157,7 @@ import org.springframework.util.StringUtils;
* asserting party , IDP , verification certificates .
* < / p >
*
* @author Ryan Cassar
* @since 5 . 2
* @see < a href =
* "https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38" > SAML 2
@ -211,6 +212,32 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -211,6 +212,32 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
private Converter < Saml2AuthenticationToken , Decrypter > decrypterConverter = new DecrypterConverter ( ) ;
private Consumer < ResponseToken > assertionDecrypter = ( responseToken ) - > {
List < Assertion > assertions = new ArrayList < > ( ) ;
for ( EncryptedAssertion encryptedAssertion : responseToken . getResponse ( ) . getEncryptedAssertions ( ) ) {
try {
Decrypter decrypter = this . decrypterConverter . convert ( responseToken . getToken ( ) ) ;
Assertion assertion = decrypter . decrypt ( encryptedAssertion ) ;
assertions . add ( assertion ) ;
}
catch ( DecryptionException ex ) {
throw createAuthenticationException ( Saml2ErrorCodes . DECRYPTION_ERROR , ex . getMessage ( ) , ex ) ;
}
}
responseToken . getResponse ( ) . getAssertions ( ) . addAll ( assertions ) ;
} ;
private Consumer < ResponseToken > principalDecrypter = ( responseToken ) - > {
try {
Decrypter decrypter = this . decrypterConverter . convert ( responseToken . getToken ( ) ) ;
Assertion assertion = CollectionUtils . firstElement ( responseToken . getResponse ( ) . getAssertions ( ) ) ;
assertion . getSubject ( ) . setNameID ( ( NameID ) decrypter . decrypt ( assertion . getSubject ( ) . getEncryptedID ( ) ) ) ;
}
catch ( DecryptionException ex ) {
throw createAuthenticationException ( Saml2ErrorCodes . DECRYPTION_ERROR , ex . getMessage ( ) , ex ) ;
}
} ;
/ * *
* Creates an { @link OpenSamlAuthenticationProvider }
* /
@ -332,6 +359,52 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -332,6 +359,52 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
this . responseTimeValidationSkew = responseTimeValidationSkew ;
}
/ * *
* Sets the assertion response custom decrypter .
*
* You can use this method like so :
*
* < pre >
* YourDecrypter decrypter = // ... your custom decrypter
*
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider ( ) ;
* provider . setAssertionDecrypter ( ( responseToken ) - > {
* Response response = responseToken . getResponse ( ) ;
* EncryptedAssertion encrypted = response . getEncryptedAssertions ( ) . get ( 0 ) ;
* Assertion assertion = decrypter . decrypt ( encrypted ) ;
* response . getAssertions ( ) . add ( assertion ) ;
* } ) ;
* < / pre >
* @param assertionDecrypter response token consumer
* /
public void setAssertionDecrypter ( Consumer < ResponseToken > assertionDecrypter ) {
Assert . notNull ( assertionDecrypter , "Consumer<ResponseToken> required" ) ;
this . assertionDecrypter = assertionDecrypter ;
}
/ * *
* Sets the principal custom decrypter .
*
* You can use this method like so :
*
* < pre >
* YourDecrypter decrypter = // ... your custom decrypter
*
* OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider ( ) ;
* provider . setAssertionDecrypter ( ( responseToken ) - > {
* Assertion assertion = CollectionUtils . firstElement ( responseToken . getResponse ( ) . getAssertions ( ) ) ;
* EncryptedID encrypted = assertion . getSubject ( ) . getEncryptedID ( ) ;
* NameID name = decrypter . decrypt ( encrypted ) ;
* assertion . getSubject ( ) . setNameID ( name )
* } ) ;
* < / pre >
* @param principalDecrypter response token consumer
* /
public void setPrincipalDecrypter ( Consumer < ResponseToken > principalDecrypter ) {
Assert . notNull ( principalDecrypter , "Consumer<ResponseToken> required" ) ;
this . principalDecrypter = principalDecrypter ;
}
/ * *
* Construct a default strategy for validating each SAML 2 . 0 Assertion and associated
* { @link Authentication } token
@ -429,8 +502,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -429,8 +502,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
boolean responseSigned = response . isSigned ( ) ;
Saml2ResponseValidatorResult result = validateResponse ( token , response ) ;
Decrypter decrypter = this . decrypterConverter . convert ( token ) ;
List < Assertion > assertions = decryptAssertions ( decrypter , response ) ;
ResponseToken responseToken = new ResponseToken ( response , token ) ;
List < Assertion > assertions = decryptAssertions ( responseToken ) ;
if ( ! isSigned ( responseSigned , assertions ) ) {
String description = "Either the response or one of the assertions is unsigned. "
+ "Please either sign the response or all of the assertions." ;
@ -439,7 +512,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -439,7 +512,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
result = result . concat ( validateAssertions ( token , response ) ) ;
Assertion firstAssertion = CollectionUtils . firstElement ( response . getAssertions ( ) ) ;
NameID nameId = decryptPrincipal ( decrypter , firstAssertio n) ;
NameID nameId = decryptPrincipal ( responseToke n) ;
if ( nameId = = null | | nameId . getValue ( ) = = null ) {
Saml2Error error = new Saml2Error ( Saml2ErrorCodes . SUBJECT_NOT_FOUND ,
"Assertion [" + firstAssertion . getID ( ) + "] is missing a subject" ) ;
@ -511,19 +584,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -511,19 +584,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return Saml2ResponseValidatorResult . failure ( errors ) ;
}
private List < Assertion > decryptAssertions ( Decrypter decrypter , Response response ) {
List < Assertion > assertions = new ArrayList < > ( ) ;
for ( EncryptedAssertion encryptedAssertion : response . getEncryptedAssertions ( ) ) {
try {
Assertion assertion = decrypter . decrypt ( encryptedAssertion ) ;
assertions . add ( assertion ) ;
}
catch ( DecryptionException ex ) {
throw createAuthenticationException ( Saml2ErrorCodes . DECRYPTION_ERROR , ex . getMessage ( ) , ex ) ;
}
}
response . getAssertions ( ) . addAll ( assertions ) ;
return response . getAssertions ( ) ;
private List < Assertion > decryptAssertions ( ResponseToken response ) {
this . assertionDecrypter . accept ( response ) ;
return response . getResponse ( ) . getAssertions ( ) ;
}
private Saml2ResponseValidatorResult validateAssertions ( Saml2AuthenticationToken token , Response response ) {
@ -567,21 +630,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
@@ -567,21 +630,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
return true ;
}
private NameID decryptPrincipal ( Decrypter decrypter , Assertion assertion ) {
private NameID decryptPrincipal ( ResponseToken responseToken ) {
Assertion assertion = CollectionUtils . firstElement ( responseToken . getResponse ( ) . getAssertions ( ) ) ;
if ( assertion . getSubject ( ) = = null ) {
return null ;
}
if ( assertion . getSubject ( ) . getEncryptedID ( ) = = null ) {
return assertion . getSubject ( ) . getNameID ( ) ;
}
try {
NameID nameId = ( NameID ) decrypter . decrypt ( assertion . getSubject ( ) . getEncryptedID ( ) ) ;
assertion . getSubject ( ) . setNameID ( nameId ) ;
return nameId ;
}
catch ( DecryptionException ex ) {
throw createAuthenticationException ( Saml2ErrorCodes . DECRYPTION_ERROR , ex . getMessage ( ) , ex ) ;
}
this . principalDecrypter . accept ( responseToken ) ;
return assertion . getSubject ( ) . getNameID ( ) ;
}
private static Map < String , List < Object > > getAssertionAttributes ( Assertion assertion ) {