@ -23,10 +23,13 @@ import org.junit.jupiter.api.Test;
@@ -23,10 +23,13 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest ;
import org.junit.jupiter.params.provider.Arguments ;
import org.junit.jupiter.params.provider.MethodSource ;
import org.mockito.Answers ;
import org.mockito.MockedStatic ;
import org.opensaml.xmlsec.signature.support.SignatureConstants ;
import org.springframework.mock.web.MockHttpServletRequest ;
import org.springframework.security.saml2.Saml2Exception ;
import org.springframework.security.saml2.core.Saml2ParameterNames ;
import org.springframework.security.saml2.core.Saml2X509Credential ;
import org.springframework.security.saml2.core.TestSaml2X509Credentials ;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest ;
@ -39,6 +42,12 @@ import org.springframework.security.saml2.provider.service.web.RelyingPartyRegis
@@ -39,6 +42,12 @@ import org.springframework.security.saml2.provider.service.web.RelyingPartyRegis
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.eq ;
import static org.mockito.Mockito.mockStatic ;
import static org.mockito.Mockito.never ;
import static org.mockito.Mockito.spy ;
import static org.mockito.Mockito.verify ;
/ * *
* Tests for { @link OpenSamlAuthenticationRequestResolver }
@ -198,6 +207,58 @@ public class OpenSamlAuthenticationRequestResolverTests {
@@ -198,6 +207,58 @@ public class OpenSamlAuthenticationRequestResolverTests {
assertThat ( result . getId ( ) ) . isNotEmpty ( ) ;
}
@Test
public void resolveAuthenticationRequestWhenSignedAndRelayStateIsNullThenSignsWithoutRelayState ( ) {
try ( MockedStatic < OpenSamlSigningUtils > openSamlSigningUtilsMockedStatic = mockStatic (
OpenSamlSigningUtils . class , Answers . CALLS_REAL_METHODS ) ) {
MockHttpServletRequest request = new MockHttpServletRequest ( ) ;
request . setPathInfo ( "/saml2/authenticate/registration-id" ) ;
RelyingPartyRegistration registration = this . relyingPartyRegistrationBuilder
. assertingPartyDetails ( ( party ) - > party . wantAuthnRequestsSigned ( true ) )
. build ( ) ;
OpenSamlSigningUtils . QueryParametersPartial queryParametersPartialSpy = spy (
new OpenSamlSigningUtils . QueryParametersPartial ( registration ) ) ;
openSamlSigningUtilsMockedStatic . when ( ( ) - > OpenSamlSigningUtils . sign ( any ( ) ) )
. thenReturn ( queryParametersPartialSpy ) ;
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver ( registration ) ;
resolver . setRelayStateResolver ( ( source ) - > null ) ;
Saml2RedirectAuthenticationRequest result = resolver . resolve ( request , ( r , authnRequest ) - > {
} ) ;
assertThat ( result . getSamlRequest ( ) ) . isNotEmpty ( ) ;
assertThat ( result . getRelayState ( ) ) . isNull ( ) ;
assertThat ( result . getSigAlg ( ) ) . isNotNull ( ) ;
assertThat ( result . getSignature ( ) ) . isNotNull ( ) ;
assertThat ( result . getBinding ( ) ) . isEqualTo ( Saml2MessageBinding . REDIRECT ) ;
verify ( queryParametersPartialSpy , never ( ) ) . param ( eq ( Saml2ParameterNames . RELAY_STATE ) , any ( ) ) ;
}
}
@Test
public void resolveAuthenticationRequestWhenSignedAndRelayStateIsEmptyThenSignsWithEmptyRelayState ( ) {
try ( MockedStatic < OpenSamlSigningUtils > openSamlSigningUtilsMockedStatic = mockStatic (
OpenSamlSigningUtils . class , Answers . CALLS_REAL_METHODS ) ) {
MockHttpServletRequest request = new MockHttpServletRequest ( ) ;
request . setPathInfo ( "/saml2/authenticate/registration-id" ) ;
RelyingPartyRegistration registration = this . relyingPartyRegistrationBuilder
. assertingPartyDetails ( ( party ) - > party . wantAuthnRequestsSigned ( true ) )
. build ( ) ;
OpenSamlSigningUtils . QueryParametersPartial queryParametersPartialSpy = spy (
new OpenSamlSigningUtils . QueryParametersPartial ( registration ) ) ;
openSamlSigningUtilsMockedStatic . when ( ( ) - > OpenSamlSigningUtils . sign ( any ( ) ) )
. thenReturn ( queryParametersPartialSpy ) ;
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver ( registration ) ;
resolver . setRelayStateResolver ( ( source ) - > "" ) ;
Saml2RedirectAuthenticationRequest result = resolver . resolve ( request , ( r , authnRequest ) - > {
} ) ;
assertThat ( result . getSamlRequest ( ) ) . isNotEmpty ( ) ;
assertThat ( result . getRelayState ( ) ) . isEmpty ( ) ;
assertThat ( result . getSigAlg ( ) ) . isNotNull ( ) ;
assertThat ( result . getSignature ( ) ) . isNotNull ( ) ;
assertThat ( result . getBinding ( ) ) . isEqualTo ( Saml2MessageBinding . REDIRECT ) ;
verify ( queryParametersPartialSpy ) . param ( eq ( Saml2ParameterNames . RELAY_STATE ) , eq ( "" ) ) ;
}
}
private OpenSamlAuthenticationRequestResolver authenticationRequestResolver ( RelyingPartyRegistration registration ) {
return new OpenSamlAuthenticationRequestResolver ( ( request , id ) - > registration ) ;
}