From 7f5c31995e65dca3ea0c56e80e879d09780cabee Mon Sep 17 00:00:00 2001 From: Ulrich Grave Date: Mon, 16 May 2022 09:17:53 +0200 Subject: [PATCH] Add relyingPartyRegistrationId to AbstractSaml2AuthenticationRequest Closes gh-11195 --- .../Saml2PostAuthenticationRequestMixin.java | 3 +- ...ml2RedirectAuthenticationRequestMixin.java | 3 +- .../AbstractSaml2AuthenticationRequest.java | 34 +++++++++++++- .../Saml2PostAuthenticationRequest.java | 13 +++--- .../Saml2RedirectAuthenticationRequest.java | 11 ++--- .../Saml2AuthenticationTokenConverter.java | 24 +++++----- ...l2PostAuthenticationRequestMixinTests.java | 17 +++++++ ...directAuthenticationRequestMixinTests.java | 20 +++++++++ .../saml2/jackson2/TestSaml2JsonPayloads.java | 13 ++++-- ...aml2AuthenticationTokenConverterTests.java | 44 +++++++++++++++++++ 10 files changed, 154 insertions(+), 28 deletions(-) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java index 3f502b61d2..53ddeb73d9 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java @@ -47,7 +47,8 @@ class Saml2PostAuthenticationRequestMixin { @JsonCreator Saml2PostAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest, @JsonProperty("relayState") String relayState, - @JsonProperty("authenticationRequestUri") String authenticationRequestUri) { + @JsonProperty("authenticationRequestUri") String authenticationRequestUri, + @JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) { } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java index 9af07a7f6b..247b52104c 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java @@ -48,7 +48,8 @@ class Saml2RedirectAuthenticationRequestMixin { Saml2RedirectAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest, @JsonProperty("sigAlg") String sigAlg, @JsonProperty("signature") String signature, @JsonProperty("relayState") String relayState, - @JsonProperty("authenticationRequestUri") String authenticationRequestUri) { + @JsonProperty("authenticationRequestUri") String authenticationRequestUri, + @JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) { } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java index 4e3781d1a6..04e6a958f8 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java @@ -20,6 +20,7 @@ import java.io.Serializable; import java.nio.charset.Charset; import org.springframework.security.core.SpringSecurityCoreVersion; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.util.Assert; @@ -46,6 +47,8 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable private final String authenticationRequestUri; + private final String relyingPartyRegistrationId; + /** * Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest} * @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or @@ -53,13 +56,17 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable * @param relayState - RelayState value that accompanies the request, may be null * @param authenticationRequestUri - The authenticationRequestUri, a URL, where to * send the XML message, cannot be empty or null + * @param relyingPartyRegistrationId the registration id of the relying party, may be + * null */ - AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri) { + AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri, + String relyingPartyRegistrationId) { Assert.hasText(samlRequest, "samlRequest cannot be null or empty"); Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty"); this.authenticationRequestUri = authenticationRequestUri; this.samlRequest = samlRequest; this.relayState = relayState; + this.relyingPartyRegistrationId = relyingPartyRegistrationId; } /** @@ -89,6 +96,16 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable return this.authenticationRequestUri; } + /** + * The identifier for the {@link RelyingPartyRegistration} associated with this + * request + * @return the {@link RelyingPartyRegistration} id + * @since 5.8 + */ + public String getRelyingPartyRegistrationId() { + return this.relyingPartyRegistrationId; + } + /** * Returns the binding this AuthNRequest will be sent and encoded with. If * {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be @@ -108,9 +125,24 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable String relayState; + String relyingPartyRegistrationId; + + /** + * @deprecated Use {@link #Builder(RelyingPartyRegistration)} instead + */ + @Deprecated protected Builder() { } + /** + * Creates a new Builder with relying party registration + * @param registration the registration of the relying party. + * @sine 5.8 + */ + protected Builder(RelyingPartyRegistration registration) { + this.relyingPartyRegistrationId = registration.getRegistrationId(); + } + /** * Casting the return as the generic subtype, when returning itself * @return this object diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java index 9bf732206c..29dc000b39 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java @@ -30,8 +30,9 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes */ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest { - Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri) { - super(samlRequest, relayState, authenticationRequestUri); + Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri, + String relyingPartyRegistrationId) { + super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId); } /** @@ -50,7 +51,7 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR */ public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) { String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation(); - return new Builder().authenticationRequestUri(location); + return new Builder(registration).authenticationRequestUri(location); } /** @@ -58,7 +59,8 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR */ public static final class Builder extends AbstractSaml2AuthenticationRequest.Builder { - private Builder() { + private Builder(RelyingPartyRegistration registration) { + super(registration); } /** @@ -66,7 +68,8 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR * @return an immutable {@link Saml2PostAuthenticationRequest} object. */ public Saml2PostAuthenticationRequest build() { - return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri); + return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri, + this.relyingPartyRegistrationId); } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java index 281d6bf48a..600ef993c9 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java @@ -35,8 +35,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe private final String signature; private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState, - String authenticationRequestUri) { - super(samlRequest, relayState, authenticationRequestUri); + String authenticationRequestUri, String relyingPartyRegistrationId) { + super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId); this.sigAlg = sigAlg; this.signature = signature; } @@ -74,7 +74,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe */ public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) { String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation(); - return new Builder().authenticationRequestUri(location); + return new Builder(registration).authenticationRequestUri(location); } /** @@ -86,7 +86,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe private String signature; - private Builder() { + private Builder(RelyingPartyRegistration registration) { + super(registration); } /** @@ -115,7 +116,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe */ public Saml2RedirectAuthenticationRequest build() { return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature, - this.relayState, this.authenticationRequestUri); + this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId); } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java index c7fa4ecbe0..56cefd4e00 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java @@ -26,7 +26,6 @@ import jakarta.servlet.http.HttpServletRequest; import org.apache.commons.codec.CodecPolicy; import org.apache.commons.codec.binary.Base64; -import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpMethod; import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2ErrorCodes; @@ -50,25 +49,29 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo private static Base64 BASE64 = new Base64(0, new byte[] { '\n' }, false, CodecPolicy.STRICT); - private final Converter relyingPartyRegistrationResolver; + private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; private Function loader; + /** + * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for + * resolving {@link RelyingPartyRegistration}s + * @param relyingPartyRegistrationResolver the strategy for resolving + * {@link RelyingPartyRegistration}s + */ public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); - this.relyingPartyRegistrationResolver = adaptToConverter(relyingPartyRegistrationResolver); + this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest; } - private static Converter adaptToConverter( - RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { - Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); - return (request) -> relyingPartyRegistrationResolver.resolve(request, null); - } - @Override public Saml2AuthenticationToken convert(HttpServletRequest request) { - RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request); + AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request); + String relyingPartyRegistrationId = (authenticationRequest != null) + ? authenticationRequest.getRelyingPartyRegistrationId() : null; + RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request, + relyingPartyRegistrationId); if (relyingPartyRegistration == null) { return null; } @@ -78,7 +81,6 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo } byte[] b = samlDecode(saml2Response); saml2Response = inflateIfRequired(request, b); - AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request); return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java index c7bd5f29b9..3183d27434 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java @@ -56,6 +56,23 @@ class Saml2PostAuthenticationRequestMixinTests { assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE); assertThat(authRequest.getAuthenticationRequestUri()) .isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI); + assertThat(authRequest.getRelyingPartyRegistrationId()) + .isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID); + } + + @Test + void shouldDeserializeWithNoRegistrationId() throws Exception { + String json = TestSaml2JsonPayloads.DEFAULT_POST_AUTH_REQUEST_JSON.replace( + "\"relyingPartyRegistrationId\": \"" + TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID + "\",", ""); + + Saml2PostAuthenticationRequest authRequest = this.mapper.readValue(json, Saml2PostAuthenticationRequest.class); + + assertThat(authRequest).isNotNull(); + assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST); + assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE); + assertThat(authRequest.getAuthenticationRequestUri()) + .isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI); + assertThat(authRequest.getRelyingPartyRegistrationId()).isNull(); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixinTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixinTests.java index 9199fb3992..d9cfa77fad 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixinTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixinTests.java @@ -59,6 +59,26 @@ class Saml2RedirectAuthenticationRequestMixinTests { .isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI); assertThat(authRequest.getSigAlg()).isEqualTo(TestSaml2JsonPayloads.SIG_ALG); assertThat(authRequest.getSignature()).isEqualTo(TestSaml2JsonPayloads.SIGNATURE); + assertThat(authRequest.getRelyingPartyRegistrationId()) + .isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID); + } + + @Test + void shouldDeserializeWithNoRegistrationId() throws Exception { + String json = TestSaml2JsonPayloads.DEFAULT_REDIRECT_AUTH_REQUEST_JSON.replace( + "\"relyingPartyRegistrationId\": \"" + TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID + "\",", ""); + + Saml2RedirectAuthenticationRequest authRequest = this.mapper.readValue(json, + Saml2RedirectAuthenticationRequest.class); + + assertThat(authRequest).isNotNull(); + assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST); + assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE); + assertThat(authRequest.getAuthenticationRequestUri()) + .isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI); + assertThat(authRequest.getSigAlg()).isEqualTo(TestSaml2JsonPayloads.SIG_ALG); + assertThat(authRequest.getSignature()).isEqualTo(TestSaml2JsonPayloads.SIGNATURE); + assertThat(authRequest.getRelyingPartyRegistrationId()).isNull(); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java index ed3e36ec3e..18f2e7deb8 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java @@ -94,6 +94,7 @@ final class TestSaml2JsonPayloads { static final String SAML_REQUEST = "samlRequestValue"; static final String RELAY_STATE = "relayStateValue"; static final String AUTHENTICATION_REQUEST_URI = "authenticationRequestUriValue"; + static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue"; static final String SIG_ALG = "sigAlgValue"; static final String SIGNATURE = "signatureValue"; @@ -103,6 +104,7 @@ final class TestSaml2JsonPayloads { + " \"samlRequest\": \"" + SAML_REQUEST + "\"," + " \"relayState\": \"" + RELAY_STATE + "\"," + " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\"," + + " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\"," + " \"sigAlg\": \"" + SIG_ALG + "\"," + " \"signature\": \"" + SIGNATURE + "\"" + "}"; @@ -113,6 +115,7 @@ final class TestSaml2JsonPayloads { + " \"@class\": \"org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest\"," + " \"samlRequest\": \"" + SAML_REQUEST + "\"," + " \"relayState\": \"" + RELAY_STATE + "\"," + + " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\"," + " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\"" + "}"; // @formatter:on @@ -120,7 +123,6 @@ final class TestSaml2JsonPayloads { static final String ID = "idValue"; static final String LOCATION = "locationValue"; static final String BINDNG = "REDIRECT"; - static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue"; static final String ADDITIONAL_PARAM = "additionalParamValue"; // @formatter:off @@ -140,14 +142,17 @@ final class TestSaml2JsonPayloads { // @formatter:on static Saml2PostAuthenticationRequest createDefaultSaml2PostAuthenticationRequest() { - return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(TestRelyingPartyRegistrations.full() - .assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI)) - .build()).samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build(); + return Saml2PostAuthenticationRequest.withRelyingPartyRegistration( + TestRelyingPartyRegistrations.full().registrationId(RELYINGPARTY_REGISTRATION_ID) + .assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI)) + .build()) + .samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build(); } static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticationRequest() { return Saml2RedirectAuthenticationRequest .withRelyingPartyRegistration(TestRelyingPartyRegistrations.full() + .registrationId(RELYINGPARTY_REGISTRATION_ID) .assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI)) .build()) .samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).build(); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java index 760fbe5d3f..ae5d096c1f 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java @@ -42,8 +42,11 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; @ExtendWith(MockitoExtension.class) public class Saml2AuthenticationTokenConverterTests { @@ -69,6 +72,21 @@ public class Saml2AuthenticationTokenConverterTests { .isEqualTo(this.relyingPartyRegistration.getRegistrationId()); } + @Test + public void convertWhenSamlResponseWithRelyingPartyRegistrationResolver( + @Mock RelyingPartyRegistrationResolver resolver) { + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver); + given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, + Saml2Utils.samlEncodeNotRfc2045("response".getBytes(StandardCharsets.UTF_8))); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.relyingPartyRegistration.getRegistrationId()); + verify(resolver).resolve(any(), isNull()); + } + @Test public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() { Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( @@ -157,6 +175,8 @@ public class Saml2AuthenticationTokenConverterTests { Saml2AuthenticationRequestRepository authenticationRequestRepository = mock( Saml2AuthenticationRequestRepository.class); AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + given(authenticationRequest.getRelyingPartyRegistrationId()) + .willReturn(this.relyingPartyRegistration.getRegistrationId()); Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( this.relyingPartyRegistrationResolver); converter.setAuthenticationRequestRepository(authenticationRequestRepository); @@ -174,6 +194,30 @@ public class Saml2AuthenticationTokenConverterTests { assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest); } + @Test + public void convertWhenSavedAuthenticationRequestThenTokenWithRelyingPartyRegistrationResolver( + @Mock RelyingPartyRegistrationResolver resolver) { + Saml2AuthenticationRequestRepository authenticationRequestRepository = mock( + Saml2AuthenticationRequestRepository.class); + AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + given(authenticationRequest.getRelyingPartyRegistrationId()) + .willReturn(this.relyingPartyRegistration.getRegistrationId()); + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver); + converter.setAuthenticationRequestRepository(authenticationRequestRepository); + given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration); + given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class))) + .willReturn(authenticationRequest); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter(Saml2ParameterNames.SAML_RESPONSE, + Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8))); + Saml2AuthenticationToken token = converter.convert(request); + assertThat(token.getSaml2Response()).isEqualTo("response"); + assertThat(token.getRelyingPartyRegistration().getRegistrationId()) + .isEqualTo(this.relyingPartyRegistration.getRegistrationId()); + assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest); + verify(resolver).resolve(any(), eq(this.relyingPartyRegistration.getRegistrationId())); + } + @Test public void constructorWhenResolverIsNullThenIllegalArgument() { assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null));