diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index d4be974be4..dccebcaf93 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -35,6 +35,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.AuthnRequest; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ConfigurableApplicationContext; @@ -89,6 +90,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential; +import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest; import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; @@ -176,8 +178,8 @@ public class Saml2LoginConfigurerTests { } @Test - public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception { - this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire(); + public void authenticationRequestWhenAuthnRequestContextConverterThenUses() throws Exception { + this.spring.register(CustomAuthenticationRequestContextConverterResolver.class).autowire(); MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")) .andReturn(); @@ -315,7 +317,7 @@ public class Saml2LoginConfigurerTests { @EnableWebSecurity @Import(Saml2LoginConfigBeans.class) - static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter { + static class CustomAuthenticationRequestContextConverterResolver extends WebSecurityConfigurerAdapter { @Override protected void configure(HttpSecurity http) throws Exception { @@ -330,8 +332,12 @@ public class Saml2LoginConfigurerTests { Saml2AuthenticationRequestFactory authenticationRequestFactory() { OpenSamlAuthenticationRequestFactory authenticationRequestFactory = new OpenSamlAuthenticationRequestFactory(); - authenticationRequestFactory.setAuthnRequestConsumerResolver( - context -> authnRequest -> authnRequest.setForceAuthn(true)); + authenticationRequestFactory.setAuthenticationRequestContextConverter( + context -> { + AuthnRequest authnRequest = authnRequest(); + authnRequest.setForceAuthn(true); + return authnRequest; + }); return authenticationRequestFactory; } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java index f35686bbc2..8900cc00a2 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java @@ -25,8 +25,6 @@ import java.util.Collection; import java.util.LinkedHashMap; import java.util.Map; import java.util.UUID; -import java.util.function.Consumer; -import java.util.function.Function; import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.joda.time.DateTime; @@ -88,8 +86,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn(); }; - private Function> authnRequestConsumerResolver - = context -> authnRequest -> {}; + private Converter authenticationRequestContextConverter + = this::createAuthnRequest; /** * Creates an {@link OpenSamlAuthenticationRequestFactory} @@ -124,7 +122,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication */ @Override public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) { - AuthnRequest authnRequest = createAuthnRequest(context); + AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context); String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ? serialize(sign(authnRequest, context.getRelyingPartyRegistration())) : serialize(authnRequest); @@ -139,7 +137,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication */ @Override public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) { - AuthnRequest authnRequest = createAuthnRequest(context); + AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context); String xml = serialize(authnRequest); Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context); String deflatedAndEncoded = samlEncode(samlDeflate(xml)); @@ -168,11 +166,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication } private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) { - AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(), + return createAuthnRequest(context.getIssuer(), context.getDestination(), context.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(context)); - this.authnRequestConsumerResolver.apply(context).accept(authnRequest); - return authnRequest; } private AuthnRequest createAuthnRequest @@ -194,13 +190,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication /** * Set the {@link AuthnRequest} post-processor resolver * - * @param authnRequestConsumerResolver + * @param authenticationRequestContextConverter * @since 5.4 */ - public void setAuthnRequestConsumerResolver( - Function> authnRequestConsumerResolver) { - Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null"); - this.authnRequestConsumerResolver = authnRequestConsumerResolver; + public void setAuthenticationRequestContextConverter( + Converter authenticationRequestContextConverter) { + Assert.notNull(authenticationRequestContextConverter, "authenticationRequestContextConverter cannot be null"); + this.authenticationRequestContextConverter = authenticationRequestContextConverter; } /** diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java index 8b8da0de9e..bd4313b599 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java @@ -17,8 +17,6 @@ package org.springframework.security.saml2.provider.service.authentication; import java.io.ByteArrayInputStream; -import java.util.function.Consumer; -import java.util.function.Function; import org.junit.Assert; import org.junit.Before; @@ -31,6 +29,7 @@ import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller; import org.w3c.dom.Document; import org.w3c.dom.Element; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; @@ -47,6 +46,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getU import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate; +import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest; import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT; @@ -63,8 +63,7 @@ public class OpenSamlAuthenticationRequestFactoryTests { private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder; private RelyingPartyRegistration relyingPartyRegistration; - private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory() - .getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); + private AuthnRequestUnmarshaller unmarshaller; @Rule public ExpectedException exception = ExpectedException.none(); @@ -84,6 +83,8 @@ public class OpenSamlAuthenticationRequestFactoryTests { .assertionConsumerServiceUrl("https://issuer/sso"); context = contextBuilder.build(); factory = new OpenSamlAuthenticationRequestFactory(); + this.unmarshaller =(AuthnRequestUnmarshaller) getUnmarshallerFactory() + .getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); } @Test @@ -182,29 +183,29 @@ public class OpenSamlAuthenticationRequestFactoryTests { @Test public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() { - Function> authnRequestConsumerResolver = - mock(Function.class); - when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {}); - this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver); + Converter authenticationRequestContextConverter = + mock(Converter.class); + when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest()); + this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter); this.factory.createPostAuthenticationRequest(this.context); - verify(authnRequestConsumerResolver).apply(this.context); + verify(authenticationRequestContextConverter).convert(this.context); } @Test public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() { - Function> authnRequestConsumerResolver = - mock(Function.class); - when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {}); - this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver); + Converter authenticationRequestContextConverter = + mock(Converter.class); + when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest()); + this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter); this.factory.createRedirectAuthenticationRequest(this.context); - verify(authnRequestConsumerResolver).apply(this.context); + verify(authenticationRequestContextConverter).convert(this.context); } @Test - public void setAuthnRequestConsumerResolverWhenNullThenException() { - assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null)) + public void setAuthenticationRequestContextConverterWhenNullThenException() { + assertThatCode(() -> this.factory.setAuthenticationRequestContextConverter(null)) .isInstanceOf(IllegalArgumentException.class); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index 3032ae3af6..b237a64498 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -53,6 +53,7 @@ import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.AttributeValue; +import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.Conditions; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedID; @@ -86,7 +87,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getB import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential; -final class TestOpenSamlObjects { +public final class TestOpenSamlObjects { static { OpenSamlInitializationService.initialize(); } @@ -188,6 +189,16 @@ final class TestOpenSamlObjects { return conditions; } + public static AuthnRequest authnRequest() { + Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME); + issuer.setValue(ASSERTING_PARTY_ENTITY_ID); + AuthnRequest authnRequest = build(AuthnRequest.DEFAULT_ELEMENT_NAME); + authnRequest.setIssuer(issuer); + authnRequest.setDestination(ASSERTING_PARTY_ENTITY_ID + "/SSO.saml2"); + authnRequest.setAssertionConsumerServiceURL(DESTINATION); + return authnRequest; + } + static Credential getSigningCredential(Saml2X509Credential credential, String entityId) { BasicCredential cred = getBasicCredential(credential); cred.setEntityId(entityId);