@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.authentication;
@@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayOutputStream ;
import java.io.IOException ;
import java.io.ObjectOutputStream ;
import java.nio.charset.StandardCharsets ;
import java.time.Duration ;
import java.time.Instant ;
import java.util.Arrays ;
@ -46,6 +47,7 @@ import org.opensaml.saml.saml2.core.Assertion;
@@ -46,6 +47,7 @@ import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute ;
import org.opensaml.saml.saml2.core.AttributeStatement ;
import org.opensaml.saml.saml2.core.AttributeValue ;
import org.opensaml.saml.saml2.core.AuthnRequest ;
import org.opensaml.saml.saml2.core.Conditions ;
import org.opensaml.saml.saml2.core.EncryptedAssertion ;
import org.opensaml.saml.saml2.core.EncryptedAttribute ;
@ -74,6 +76,7 @@ import org.springframework.security.saml2.core.TestSaml2X509Credentials;
@@ -74,6 +76,7 @@ import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken ;
import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject ;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration ;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding ;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations ;
import org.springframework.util.StringUtils ;
@ -217,6 +220,111 @@ public class OpenSaml4AuthenticationProviderTests {
@@ -217,6 +220,111 @@ public class OpenSaml4AuthenticationProviderTests {
this . provider . authenticate ( token ) ;
}
@Test
public void evaluateInResponseToSucceedsWhenInResponseToInResponseAndAssertionsMatchRequestID ( ) {
Response response = response ( ) ;
response . setInResponseTo ( "SAML2" ) ;
response . getAssertions ( ) . add ( signed ( assertion ( "SAML2" ) ) ) ;
response . getAssertions ( ) . add ( signed ( assertion ( "SAML2" ) ) ) ;
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest ( "SAML2" ,
Saml2MessageBinding . POST , false ) ;
Saml2AuthenticationToken token = token ( response , verifying ( registration ( ) ) , mockAuthenticationRequest ) ;
this . provider . authenticate ( token ) ;
}
@Test
public void evaluateInResponseToSucceedsWhenInResponseToInAssertionOnlyMatchRequestID ( ) {
Response response = response ( ) ;
response . getAssertions ( ) . add ( signed ( assertion ( ) ) ) ;
response . getAssertions ( ) . add ( signed ( assertion ( "SAML2" ) ) ) ;
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest ( "SAML2" ,
Saml2MessageBinding . POST , false ) ;
Saml2AuthenticationToken token = token ( response , verifying ( registration ( ) ) , mockAuthenticationRequest ) ;
this . provider . authenticate ( token ) ;
}
@Test
public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndCorruptedStoredRequest ( ) {
Response response = response ( ) ;
response . getAssertions ( ) . add ( signed ( assertion ( ) ) ) ;
response . getAssertions ( ) . add ( signed ( assertion ( "SAML2" ) ) ) ;
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest ( "SAML2" ,
Saml2MessageBinding . POST , true ) ;
Saml2AuthenticationToken token = token ( response , verifying ( registration ( ) ) , mockAuthenticationRequest ) ;
assertThatExceptionOfType ( Saml2AuthenticationException . class )
. isThrownBy ( ( ) - > this . provider . authenticate ( token ) ) . withStackTraceContaining ( "invalid_assertion" ) ;
}
@Test
public void evaluateInResponseToFailsWhenInResponseToInAssertionMismatchWithRequestID ( ) {
Response response = response ( ) ;
response . setInResponseTo ( "SAML2" ) ;
response . getAssertions ( ) . add ( signed ( assertion ( "SAML2" ) ) ) ;
response . getAssertions ( ) . add ( signed ( assertion ( "BAD" ) ) ) ;
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest ( "SAML2" ,
Saml2MessageBinding . POST , false ) ;
Saml2AuthenticationToken token = token ( response , verifying ( registration ( ) ) , mockAuthenticationRequest ) ;
assertThatExceptionOfType ( Saml2AuthenticationException . class )
. isThrownBy ( ( ) - > this . provider . authenticate ( token ) ) . withStackTraceContaining ( "invalid_assertion" ) ;
}
@Test
public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndMismatchWithRequestID ( ) {
Response response = response ( ) ;
response . getAssertions ( ) . add ( signed ( assertion ( ) ) ) ;
response . getAssertions ( ) . add ( signed ( assertion ( "BAD" ) ) ) ;
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest ( "SAML2" ,
Saml2MessageBinding . POST , false ) ;
Saml2AuthenticationToken token = token ( response , verifying ( registration ( ) ) , mockAuthenticationRequest ) ;
assertThatExceptionOfType ( Saml2AuthenticationException . class )
. isThrownBy ( ( ) - > this . provider . authenticate ( token ) ) . withStackTraceContaining ( "invalid_assertion" ) ;
}
@Test
public void evaluateInResponseToFailsWhenInResponseInToResponseMismatchWithRequestID ( ) {
Response response = response ( ) ;
response . setInResponseTo ( "BAD" ) ;
response . getAssertions ( ) . add ( signed ( assertion ( "SAML2" ) ) ) ;
response . getAssertions ( ) . add ( signed ( assertion ( "SAML2" ) ) ) ;
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest ( "SAML2" ,
Saml2MessageBinding . POST , false ) ;
Saml2AuthenticationToken token = token ( response , verifying ( registration ( ) ) , mockAuthenticationRequest ) ;
assertThatExceptionOfType ( Saml2AuthenticationException . class )
. isThrownBy ( ( ) - > this . provider . authenticate ( token ) ) . withStackTraceContaining ( "invalid_in_response_to" ) ;
}
@Test
public void evaluateInResponseToFailsWhenInResponseInToResponseAndCorruptedStoredRequest ( ) {
Response response = response ( ) ;
response . setInResponseTo ( "SAML2" ) ;
response . getAssertions ( ) . add ( signed ( assertion ( ) ) ) ;
response . getAssertions ( ) . add ( signed ( assertion ( ) ) ) ;
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest ( "SAML2" ,
Saml2MessageBinding . POST , true ) ;
Saml2AuthenticationToken token = token ( response , verifying ( registration ( ) ) , mockAuthenticationRequest ) ;
assertThatExceptionOfType ( Saml2AuthenticationException . class )
. isThrownBy ( ( ) - > this . provider . authenticate ( token ) ) . withStackTraceContaining ( "malformed_request_data" ) ;
}
@Test
public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest ( ) {
Response response = response ( ) ;
response . setInResponseTo ( "BAD" ) ;
Saml2AuthenticationToken token = token ( response , verifying ( registration ( ) ) ) ;
assertThatExceptionOfType ( Saml2AuthenticationException . class )
. isThrownBy ( ( ) - > this . provider . authenticate ( token ) ) . withStackTraceContaining ( "invalid_in_response_to" ) ;
}
@Test
public void evaluateInResponseToSucceedsWhenNoInResponseToInResponseOrAssertions ( ) {
Response response = response ( ) ;
response . getAssertions ( ) . add ( signed ( assertion ( ) ) ) ;
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest ( "SAML2" ,
Saml2MessageBinding . POST , false ) ;
Saml2AuthenticationToken token = token ( response , verifying ( registration ( ) ) , mockAuthenticationRequest ) ;
this . provider . authenticate ( token ) ;
}
@Test
public void authenticateWhenAssertionContainsAttributesThenItSucceeds ( ) {
Response response = response ( ) ;
@ -658,13 +766,27 @@ public class OpenSaml4AuthenticationProviderTests {
@@ -658,13 +766,27 @@ public class OpenSaml4AuthenticationProviderTests {
return response ;
}
private Assertion assertion ( ) {
private AuthnRequest request ( ) {
AuthnRequest request = TestOpenSamlObjects . authnRequest ( ) ;
return request ;
}
private String serializedRequest ( AuthnRequest request , Saml2MessageBinding binding ) {
String xml = serialize ( request ) ;
return ( binding = = Saml2MessageBinding . POST ) ? Saml2Utils . samlEncode ( xml . getBytes ( StandardCharsets . UTF_8 ) )
: Saml2Utils . samlEncode ( Saml2Utils . samlDeflate ( xml ) ) ;
}
private Assertion assertion ( String inResponseTo ) {
Assertion assertion = TestOpenSamlObjects . assertion ( ) ;
assertion . setIssueInstant ( Instant . now ( ) ) ;
for ( SubjectConfirmation confirmation : assertion . getSubject ( ) . getSubjectConfirmations ( ) ) {
SubjectConfirmationData data = confirmation . getSubjectConfirmationData ( ) ;
data . setNotBefore ( Instant . now ( ) . minus ( Duration . ofMillis ( 5 * 60 * 1000 ) ) ) ;
data . setNotOnOrAfter ( Instant . now ( ) . plus ( Duration . ofMillis ( 5 * 60 * 1000 ) ) ) ;
if ( StringUtils . hasText ( inResponseTo ) ) {
data . setInResponseTo ( inResponseTo ) ;
}
}
Conditions conditions = assertion . getConditions ( ) ;
conditions . setNotBefore ( Instant . now ( ) . minus ( Duration . ofMillis ( 5 * 60 * 1000 ) ) ) ;
@ -672,6 +794,10 @@ public class OpenSaml4AuthenticationProviderTests {
@@ -672,6 +794,10 @@ public class OpenSaml4AuthenticationProviderTests {
return assertion ;
}
private Assertion assertion ( ) {
return assertion ( null ) ;
}
private < T extends SignableSAMLObject > T signed ( T toSign ) {
TestOpenSamlObjects . signed ( toSign , TestSaml2X509Credentials . assertingPartySigningCredential ( ) ,
RELYING_PARTY_ENTITY_ID ) ;
@ -701,6 +827,27 @@ public class OpenSaml4AuthenticationProviderTests {
@@ -701,6 +827,27 @@ public class OpenSaml4AuthenticationProviderTests {
return new Saml2AuthenticationToken ( registration . build ( ) , serialize ( response ) ) ;
}
private Saml2AuthenticationToken token ( Response response , RelyingPartyRegistration . Builder registration ,
AbstractSaml2AuthenticationRequest authenticationRequest ) {
return new Saml2AuthenticationToken ( registration . build ( ) , serialize ( response ) , authenticationRequest ) ;
}
private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest ( String requestId ,
Saml2MessageBinding binding , boolean corruptRequestString ) {
AuthnRequest request = request ( ) ;
if ( requestId ! = null ) {
request . setID ( requestId ) ;
}
String serializedRequest = serializedRequest ( request , binding ) ;
if ( corruptRequestString ) {
serializedRequest = serializedRequest . substring ( 2 , serializedRequest . length ( ) - 2 ) ;
}
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock ( AbstractSaml2AuthenticationRequest . class ) ;
given ( mockAuthenticationRequest . getSamlRequest ( ) ) . willReturn ( serializedRequest ) ;
given ( mockAuthenticationRequest . getBinding ( ) ) . willReturn ( binding ) ;
return mockAuthenticationRequest ;
}
private RelyingPartyRegistration . Builder registration ( ) {
return TestRelyingPartyRegistrations . noCredentials ( ) . entityId ( RELYING_PARTY_ENTITY_ID )
. assertionConsumerServiceLocation ( DESTINATION )