@ -17,8 +17,6 @@
package org.springframework.security.saml2.provider.service.authentication ;
package org.springframework.security.saml2.provider.service.authentication ;
import java.io.ByteArrayInputStream ;
import java.io.ByteArrayInputStream ;
import java.util.function.Consumer ;
import java.util.function.Function ;
import org.junit.Assert ;
import org.junit.Assert ;
import org.junit.Before ;
import org.junit.Before ;
@ -31,6 +29,7 @@ import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
import org.w3c.dom.Document ;
import org.w3c.dom.Document ;
import org.w3c.dom.Element ;
import org.w3c.dom.Element ;
import org.springframework.core.convert.converter.Converter ;
import org.springframework.security.saml2.Saml2Exception ;
import org.springframework.security.saml2.Saml2Exception ;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration ;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration ;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding ;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding ;
@ -47,6 +46,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getU
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential ;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate ;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest ;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration ;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration ;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST ;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST ;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT ;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT ;
@ -63,8 +63,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
private RelyingPartyRegistration . Builder relyingPartyRegistrationBuilder ;
private RelyingPartyRegistration . Builder relyingPartyRegistrationBuilder ;
private RelyingPartyRegistration relyingPartyRegistration ;
private RelyingPartyRegistration relyingPartyRegistration ;
private AuthnRequestUnmarshaller unmarshaller = ( AuthnRequestUnmarshaller ) getUnmarshallerFactory ( )
private AuthnRequestUnmarshaller unmarshaller ;
. getUnmarshaller ( AuthnRequest . DEFAULT_ELEMENT_NAME ) ;
@Rule
@Rule
public ExpectedException exception = ExpectedException . none ( ) ;
public ExpectedException exception = ExpectedException . none ( ) ;
@ -84,6 +83,8 @@ public class OpenSamlAuthenticationRequestFactoryTests {
. assertionConsumerServiceUrl ( "https://issuer/sso" ) ;
. assertionConsumerServiceUrl ( "https://issuer/sso" ) ;
context = contextBuilder . build ( ) ;
context = contextBuilder . build ( ) ;
factory = new OpenSamlAuthenticationRequestFactory ( ) ;
factory = new OpenSamlAuthenticationRequestFactory ( ) ;
this . unmarshaller = ( AuthnRequestUnmarshaller ) getUnmarshallerFactory ( )
. getUnmarshaller ( AuthnRequest . DEFAULT_ELEMENT_NAME ) ;
}
}
@Test
@Test
@ -182,29 +183,29 @@ public class OpenSamlAuthenticationRequestFactoryTests {
@Test
@Test
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses ( ) {
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses ( ) {
Function < Saml2AuthenticationRequestContext , Consumer < AuthnRequest > > authnRequestConsumerResolv er =
Converter < Saml2AuthenticationRequestContext , AuthnRequest > authenticationRequestContextConvert er =
mock ( Function . class ) ;
mock ( Converter . class ) ;
when ( authnRequestConsumerResolver . apply ( this . context ) ) . thenReturn ( authnRequest - > { } ) ;
when ( authenticationRequestContextConverter . convert ( this . context ) ) . thenReturn ( authnRequest ( ) ) ;
this . factory . setAuthnRequestConsumerResolver ( authnRequestConsumerResolv er ) ;
this . factory . setAuthenticationRequestContextConverter ( authenticationRequestContextConvert er ) ;
this . factory . createPostAuthenticationRequest ( this . context ) ;
this . factory . createPostAuthenticationRequest ( this . context ) ;
verify ( authnRequestConsumerResolver ) . apply ( this . context ) ;
verify ( authenticationRequestContextConverter ) . convert ( this . context ) ;
}
}
@Test
@Test
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses ( ) {
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses ( ) {
Function < Saml2AuthenticationRequestContext , Consumer < AuthnRequest > > authnRequestConsumerResolv er =
Converter < Saml2AuthenticationRequestContext , AuthnRequest > authenticationRequestContextConvert er =
mock ( Function . class ) ;
mock ( Converter . class ) ;
when ( authnRequestConsumerResolver . apply ( this . context ) ) . thenReturn ( authnRequest - > { } ) ;
when ( authenticationRequestContextConverter . convert ( this . context ) ) . thenReturn ( authnRequest ( ) ) ;
this . factory . setAuthnRequestConsumerResolver ( authnRequestConsumerResolv er ) ;
this . factory . setAuthenticationRequestContextConverter ( authenticationRequestContextConvert er ) ;
this . factory . createRedirectAuthenticationRequest ( this . context ) ;
this . factory . createRedirectAuthenticationRequest ( this . context ) ;
verify ( authnRequestConsumerResolver ) . apply ( this . context ) ;
verify ( authenticationRequestContextConverter ) . convert ( this . context ) ;
}
}
@Test
@Test
public void setAuthnRequestConsumerResolv erWhenNullThenException ( ) {
public void setAuthenticationRequestContextConvert erWhenNullThenException ( ) {
assertThatCode ( ( ) - > this . factory . setAuthnRequestConsumerResolv er ( null ) )
assertThatCode ( ( ) - > this . factory . setAuthenticationRequestContextConvert er ( null ) )
. isInstanceOf ( IllegalArgumentException . class ) ;
. isInstanceOf ( IllegalArgumentException . class ) ;
}
}