@ -65,7 +65,6 @@ import org.opensaml.saml.saml2.core.EncryptedAssertion;
@@ -65,7 +65,6 @@ import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.OneTimeUse ;
import org.opensaml.saml.saml2.core.Response ;
import org.opensaml.saml.saml2.core.StatusCode ;
import org.opensaml.saml.saml2.core.Subject ;
import org.opensaml.saml.saml2.core.SubjectConfirmation ;
import org.opensaml.saml.saml2.core.SubjectConfirmationData ;
import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller ;
@ -146,6 +145,13 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
@@ -146,6 +145,13 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
private final ResponseUnmarshaller responseUnmarshaller ;
private static final AuthnRequestUnmarshaller authnRequestUnmarshaller ;
static {
XMLObjectProviderRegistry registry = ConfigurationService . get ( XMLObjectProviderRegistry . class ) ;
authnRequestUnmarshaller = ( AuthnRequestUnmarshaller ) registry . getUnmarshallerFactory ( )
. getUnmarshaller ( AuthnRequest . DEFAULT_ELEMENT_NAME ) ;
}
private final ParserPool parserPool ;
private final Converter < ResponseToken , Saml2ResponseValidatorResult > responseSignatureValidator = createDefaultResponseSignatureValidator ( ) ;
@ -355,37 +361,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
@@ -355,37 +361,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
this . responseAuthenticationConverter = responseAuthenticationConverter ;
}
private static Saml2ResponseValidatorResult validateInResponseTo ( AbstractSaml2AuthenticationRequest storedRequest ,
String inResponseTo ) {
if ( ! StringUtils . hasText ( inResponseTo ) ) {
return Saml2ResponseValidatorResult . success ( ) ;
}
AuthnRequest request ;
try {
request = parseRequest ( storedRequest ) ;
}
catch ( Exception ex ) {
String message = "The stored AuthNRequest could not be properly deserialized [" + ex . getMessage ( ) + "]" ;
return Saml2ResponseValidatorResult
. failure ( new Saml2Error ( Saml2ErrorCodes . MALFORMED_REQUEST_DATA , message ) ) ;
}
if ( request = = null ) {
String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
+ " but no saved AuthNRequest request was found" ;
return Saml2ResponseValidatorResult
. failure ( new Saml2Error ( Saml2ErrorCodes . INVALID_IN_RESPONSE_TO , message ) ) ;
}
else if ( ! request . getID ( ) . equals ( inResponseTo ) ) {
String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
+ "AuthNRequest [" + request . getID ( ) + "]" ;
return Saml2ResponseValidatorResult
. failure ( new Saml2Error ( Saml2ErrorCodes . INVALID_IN_RESPONSE_TO , message ) ) ;
}
else {
return Saml2ResponseValidatorResult . success ( ) ;
}
}
/ * *
* Construct a default strategy for validating the SAML 2 . 0 Response
* @return the default response validator strategy
@ -428,6 +403,27 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
@@ -428,6 +403,27 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
} ;
}
private static Saml2ResponseValidatorResult validateInResponseTo ( AbstractSaml2AuthenticationRequest storedRequest ,
String inResponseTo ) {
if ( ! StringUtils . hasText ( inResponseTo ) ) {
return Saml2ResponseValidatorResult . success ( ) ;
}
AuthnRequest request = parseRequest ( storedRequest ) ;
if ( request = = null ) {
String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
+ " but no saved authentication request was found" ;
return Saml2ResponseValidatorResult
. failure ( new Saml2Error ( Saml2ErrorCodes . INVALID_IN_RESPONSE_TO , message ) ) ;
}
if ( ! inResponseTo . equals ( request . getID ( ) ) ) {
String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
+ "authentication request [" + request . getID ( ) + "]" ;
return Saml2ResponseValidatorResult
. failure ( new Saml2Error ( Saml2ErrorCodes . INVALID_IN_RESPONSE_TO , message ) ) ;
}
return Saml2ResponseValidatorResult . success ( ) ;
}
/ * *
* Construct a default strategy for validating each SAML 2 . 0 Assertion and associated
* { @link Authentication } token
@ -522,28 +518,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
@@ -522,28 +518,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
}
}
private static AuthnRequest parseRequest ( AbstractSaml2AuthenticationRequest request ) throws Exception {
if ( request = = null ) {
return null ;
}
String samlRequest = request . getSamlRequest ( ) ;
if ( ! StringUtils . hasText ( samlRequest ) ) {
return null ;
}
if ( request . getBinding ( ) = = Saml2MessageBinding . REDIRECT ) {
samlRequest = Saml2Utils . samlInflate ( Saml2Utils . samlDecode ( samlRequest ) ) ;
}
else {
samlRequest = new String ( Saml2Utils . samlDecode ( samlRequest ) , StandardCharsets . UTF_8 ) ;
}
Document document = XMLObjectProviderRegistrySupport . getParserPool ( )
. parse ( new ByteArrayInputStream ( samlRequest . getBytes ( StandardCharsets . UTF_8 ) ) ) ;
Element element = document . getDocumentElement ( ) ;
AuthnRequestUnmarshaller unmarshaller = ( AuthnRequestUnmarshaller ) XMLObjectProviderRegistrySupport
. getUnmarshallerFactory ( ) . getUnmarshaller ( AuthnRequest . DEFAULT_ELEMENT_NAME ) ;
return ( AuthnRequest ) unmarshaller . unmarshall ( element ) ;
}
private void process ( Saml2AuthenticationToken token , Response response ) {
String issuer = response . getIssuer ( ) . getValue ( ) ;
this . logger . debug ( LogMessage . format ( "Processing SAML response from %s" , issuer ) ) ;
@ -748,40 +722,18 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
@@ -748,40 +722,18 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
} ;
}
private static boolean assertionContainsInResponseTo ( Assertion assertion ) {
Subject subject = ( assertion ! = null ) ? assertion . getSubject ( ) : null ;
List < SubjectConfirmation > confirmations = ( subject ! = null ) ? subject . getSubjectConfirmations ( )
: new ArrayList < > ( ) ;
return confirmations . stream ( ) . filter ( ( confirmation ) - > {
SubjectConfirmationData confirmationData = confirmation . getSubjectConfirmationData ( ) ;
return confirmationData ! = null & & StringUtils . hasText ( confirmationData . getInResponseTo ( ) ) ;
} ) . findFirst ( ) . orElse ( null ) ! = null ;
}
private static void addRequestIdToValidationContext ( AbstractSaml2AuthenticationRequest storedRequest ,
Map < String , Object > context ) {
String requestId = null ;
try {
AuthnRequest request = parseRequest ( storedRequest ) ;
requestId = ( request ! = null ) ? request . getID ( ) : null ;
}
catch ( Exception ex ) {
}
if ( StringUtils . hasText ( requestId ) ) {
context . put ( SAML2AssertionValidationParameters . SC_VALID_IN_RESPONSE_TO , requestId ) ;
}
}
private static ValidationContext createValidationContext ( AssertionToken assertionToken ,
Consumer < Map < String , Object > > paramsConsumer ) {
RelyingPartyRegistration relyingPartyRegistration = assertionToken . token . getRelyingPartyRegistration ( ) ;
Saml2AuthenticationToken token = assertionToken . token ;
RelyingPartyRegistration relyingPartyRegistration = token . getRelyingPartyRegistration ( ) ;
String audience = relyingPartyRegistration . getEntityId ( ) ;
String recipient = relyingPartyRegistration . getAssertionConsumerServiceLocation ( ) ;
String assertingPartyEntityId = relyingPartyRegistration . getAssertingPartyDetails ( ) . getEntityId ( ) ;
Map < String , Object > params = new HashMap < > ( ) ;
Assertion assertion = assertionToken . getAssertion ( ) ;
if ( assertionContainsInResponseTo ( assertion ) ) {
addRequestIdToValidationContext ( assertionToken . token . getAuthenticationRequest ( ) , params ) ;
String requestId = getAuthnRequestId ( token . getAuthenticationRequest ( ) ) ;
params . put ( SAML2AssertionValidationParameters . SC_VALID_IN_RESPONSE_TO , requestId ) ;
}
params . put ( SAML2AssertionValidationParameters . COND_VALID_AUDIENCES , Collections . singleton ( audience ) ) ;
params . put ( SAML2AssertionValidationParameters . SC_VALID_RECIPIENTS , Collections . singleton ( recipient ) ) ;
@ -790,6 +742,56 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
@@ -790,6 +742,56 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
return new ValidationContext ( params ) ;
}
private static boolean assertionContainsInResponseTo ( Assertion assertion ) {
if ( assertion . getSubject ( ) = = null ) {
return false ;
}
for ( SubjectConfirmation confirmation : assertion . getSubject ( ) . getSubjectConfirmations ( ) ) {
SubjectConfirmationData confirmationData = confirmation . getSubjectConfirmationData ( ) ;
if ( confirmationData = = null ) {
continue ;
}
if ( StringUtils . hasText ( confirmationData . getInResponseTo ( ) ) ) {
return true ;
}
}
return false ;
}
private static String getAuthnRequestId ( AbstractSaml2AuthenticationRequest serialized ) {
AuthnRequest request = parseRequest ( serialized ) ;
if ( request = = null ) {
return null ;
}
return request . getID ( ) ;
}
private static AuthnRequest parseRequest ( AbstractSaml2AuthenticationRequest request ) {
if ( request = = null ) {
return null ;
}
String samlRequest = request . getSamlRequest ( ) ;
if ( ! StringUtils . hasText ( samlRequest ) ) {
return null ;
}
if ( request . getBinding ( ) = = Saml2MessageBinding . REDIRECT ) {
samlRequest = Saml2Utils . samlInflate ( Saml2Utils . samlDecode ( samlRequest ) ) ;
}
else {
samlRequest = new String ( Saml2Utils . samlDecode ( samlRequest ) , StandardCharsets . UTF_8 ) ;
}
try {
Document document = XMLObjectProviderRegistrySupport . getParserPool ( )
. parse ( new ByteArrayInputStream ( samlRequest . getBytes ( StandardCharsets . UTF_8 ) ) ) ;
Element element = document . getDocumentElement ( ) ;
return ( AuthnRequest ) authnRequestUnmarshaller . unmarshall ( element ) ;
}
catch ( Exception ex ) {
String message = "Failed to deserialize associated authentication request [" + ex . getMessage ( ) + "]" ;
throw createAuthenticationException ( Saml2ErrorCodes . MALFORMED_REQUEST_DATA , message , ex ) ;
}
}
private static class SAML20AssertionValidators {
private static final Collection < ConditionValidator > conditions = new ArrayList < > ( ) ;