From 538db29bfea496185011fdca34a53b3dff9af336 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 28 Feb 2023 13:57:55 -0700 Subject: [PATCH] Add RelyingPartyRegstration#mutate Closes gh-12841 --- .../RelyingPartyRegistration.java | 59 ++++++++++++++----- ...faultRelyingPartyRegistrationResolver.java | 4 +- .../RelyingPartyRegistrationTests.java | 12 ++++ 3 files changed, 57 insertions(+), 18 deletions(-) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java index ff2fe00452..7360e810ba 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java @@ -130,6 +130,35 @@ public final class RelyingPartyRegistration { this.signingX509Credentials = Collections.unmodifiableList(new LinkedList<>(signingX509Credentials)); } + /** + * Copy the properties in this {@link RelyingPartyRegistration} into a {@link Builder} + * @return a {@link Builder} based off of the properties in this + * {@link RelyingPartyRegistration} + * @since 6.1 + */ + public Builder mutate() { + AssertingPartyDetails party = this.assertingPartyDetails; + return withRegistrationId(this.registrationId).entityId(this.entityId) + .signingX509Credentials((c) -> c.addAll(this.signingX509Credentials)) + .decryptionX509Credentials((c) -> c.addAll(this.decryptionX509Credentials)) + .assertionConsumerServiceLocation(this.assertionConsumerServiceLocation) + .assertionConsumerServiceBinding(this.assertionConsumerServiceBinding) + .singleLogoutServiceLocation(this.singleLogoutServiceLocation) + .singleLogoutServiceResponseLocation(this.singleLogoutServiceResponseLocation) + .singleLogoutServiceBindings((c) -> c.addAll(this.singleLogoutServiceBindings)) + .nameIdFormat(this.nameIdFormat) + .assertingPartyDetails((assertingParty) -> assertingParty.entityId(party.getEntityId()) + .wantAuthnRequestsSigned(party.getWantAuthnRequestsSigned()) + .signingAlgorithms((algorithms) -> algorithms.addAll(party.getSigningAlgorithms())) + .verificationX509Credentials((c) -> c.addAll(party.getVerificationX509Credentials())) + .encryptionX509Credentials((c) -> c.addAll(party.getEncryptionX509Credentials())) + .singleSignOnServiceLocation(party.getSingleSignOnServiceLocation()) + .singleSignOnServiceBinding(party.getSingleSignOnServiceBinding()) + .singleLogoutServiceLocation(party.getSingleLogoutServiceLocation()) + .singleLogoutServiceResponseLocation(party.getSingleLogoutServiceResponseLocation()) + .singleLogoutServiceBinding(party.getSingleLogoutServiceBinding())); + } + /** * Get the unique registration id for this RP/AP pair * @return the unique registration id for this RP/AP pair @@ -292,7 +321,7 @@ public final class RelyingPartyRegistration { */ public static Builder withRegistrationId(String registrationId) { Assert.hasText(registrationId, "registrationId cannot be empty"); - return new Builder(registrationId); + return new Builder(registrationId, new AssertingPartyDetails.Builder()); } public static Builder withAssertingPartyDetails(AssertingPartyDetails assertingPartyDetails) { @@ -315,7 +344,9 @@ public final class RelyingPartyRegistration { * object * @param registration the {@code RelyingPartyRegistration} * @return {@code Builder} to create a {@code RelyingPartyRegistration} object + * @deprecated Use {@link #mutate()} instead */ + @Deprecated(forRemoval = true, since = "6.1") public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) { Assert.notNull(registration, "registration cannot be null"); return withRegistrationId(registration.getRegistrationId()).entityId(registration.getEntityId()) @@ -736,9 +767,9 @@ public final class RelyingPartyRegistration { } - public static final class Builder { + public static class Builder { - private Converter registrationId = AssertingPartyDetails::getEntityId; + private String registrationId; private String entityId = "{baseUrl}/saml2/service-provider-metadata/{registrationId}"; @@ -760,13 +791,9 @@ public final class RelyingPartyRegistration { private AssertingPartyDetails.Builder assertingPartyDetailsBuilder; - private Builder(String registrationId) { - this.registrationId = (party) -> registrationId; - this.assertingPartyDetailsBuilder = new AssertingPartyDetails.Builder(); - } - - Builder(AssertingPartyDetails.Builder builder) { - this.assertingPartyDetailsBuilder = builder; + protected Builder(String registrationId, AssertingPartyDetails.Builder assertingPartyDetailsBuilder) { + this.registrationId = registrationId; + this.assertingPartyDetailsBuilder = assertingPartyDetailsBuilder; } /** @@ -775,7 +802,7 @@ public final class RelyingPartyRegistration { * @return this object */ public Builder registrationId(String id) { - this.registrationId = (party) -> id; + this.registrationId = id; return this; } @@ -974,11 +1001,11 @@ public final class RelyingPartyRegistration { } AssertingPartyDetails party = this.assertingPartyDetailsBuilder.build(); - String registrationId = this.registrationId.convert(party); - return new RelyingPartyRegistration(registrationId, this.entityId, this.assertionConsumerServiceLocation, - this.assertionConsumerServiceBinding, this.singleLogoutServiceLocation, - this.singleLogoutServiceResponseLocation, this.singleLogoutServiceBindings, party, - this.nameIdFormat, this.decryptionX509Credentials, this.signingX509Credentials); + return new RelyingPartyRegistration(this.registrationId, this.entityId, + this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding, + this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation, + this.singleLogoutServiceBindings, party, this.nameIdFormat, this.decryptionX509Credentials, + this.signingX509Credentials); } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java index 83235e72e9..446cbf3bc6 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java @@ -101,8 +101,8 @@ public final class DefaultRelyingPartyRegistrationResolver .apply(relyingPartyRegistration.getSingleLogoutServiceLocation()); String singleLogoutServiceResponseLocation = templateResolver .apply(relyingPartyRegistration.getSingleLogoutServiceResponseLocation()); - return RelyingPartyRegistration.withRelyingPartyRegistration(relyingPartyRegistration) - .entityId(relyingPartyEntityId).assertionConsumerServiceLocation(assertionConsumerServiceLocation) + return relyingPartyRegistration.mutate().entityId(relyingPartyEntityId) + .assertionConsumerServiceLocation(assertionConsumerServiceLocation) .singleLogoutServiceLocation(singleLogoutServiceLocation) .singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation).build(); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java index dc202b4574..05293bf28a 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java @@ -38,6 +38,18 @@ public class RelyingPartyRegistrationTests { compareRegistrations(registration, copy); } + @Test + void mutateWhenInvokedThenCreatesCopy() { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration() + .nameIdFormat("format") + .assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST)) + .assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false)) + .assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg"))) + .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build(); + RelyingPartyRegistration copy = registration.mutate().build(); + compareRegistrations(registration, copy); + } + private void compareRegistrations(RelyingPartyRegistration registration, RelyingPartyRegistration copy) { assertThat(copy.getRegistrationId()).isEqualTo(registration.getRegistrationId()).isEqualTo("simplesamlphp"); assertThat(copy.getAssertingPartyDetails().getEntityId())