diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java index d72d6772e8..84c7e20332 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java @@ -63,6 +63,16 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { this.saml2MetadataResolver = saml2MetadataResolver; } + /** + * Constructs an instance of {@link Saml2MetadataFilter} using the provided + * parameters. The {@link #relyingPartyRegistrationResolver} field will be initialized + * with a {@link DefaultRelyingPartyRegistrationResolver} instance using the provided + * {@link RelyingPartyRegistrationRepository} + * @param relyingPartyRegistrationRepository the + * {@link RelyingPartyRegistrationRepository} to use + * @param saml2MetadataResolver the {@link Saml2MetadataResolver} to use + * @since 6.1 + */ public Saml2MetadataFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, Saml2MetadataResolver saml2MetadataResolver) { this(new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository), saml2MetadataResolver); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java index a8fc1cdaae..3c5771e86e 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java @@ -33,6 +33,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.test.util.ReflectionTestUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -137,7 +138,7 @@ public class Saml2MetadataFilterTests { } @Test - public void doFilterWhenPathStartsWithRegistrationIdThenServesMetadata() throws Exception { + public void doFilterWhenResolverConstructorAndPathStartsWithRegistrationIdThenServesMetadata() throws Exception { RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); given(this.repository.findByRegistrationId("registration-id")).willReturn(registration); given(this.resolver.resolve(any())).willReturn("metadata"); @@ -151,16 +152,17 @@ public class Saml2MetadataFilterTests { } @Test - public void doFilterWhenPathStartsWithOneThenServesMetadata() throws Exception { + public void doFilterWhenRelyingPartyRegistrationRepositoryConstructorAndPathStartsWithRegistrationIdThenServesMetadata() + throws Exception { RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); - given(this.repository.findByRegistrationId("one")).willReturn(registration); + given(this.repository.findByRegistrationId("registration-id")).willReturn(registration); given(this.resolver.resolve(any())).willReturn("metadata"); - this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("one"), + this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("registration-id"), this.resolver); this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata")); this.request.setPathInfo("/metadata"); this.filter.doFilter(this.request, this.response, new MockFilterChain()); - verify(this.repository).findByRegistrationId("one"); + verify(this.repository).findByRegistrationId("registration-id"); } // gh-12026 @@ -196,4 +198,14 @@ public class Saml2MetadataFilterTests { .withMessage("metadataFilename must contain a {registrationId} match variable"); } + @Test + public void constructorWhenRelyingPartyRegistrationRepositoryThenUses() { + RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); + this.filter = new Saml2MetadataFilter(repository, this.resolver); + DefaultRelyingPartyRegistrationResolver relyingPartyRegistrationResolver = (DefaultRelyingPartyRegistrationResolver) ReflectionTestUtils + .getField(this.filter, "relyingPartyRegistrationResolver"); + relyingPartyRegistrationResolver.resolve(this.request, "one"); + verify(repository).findByRegistrationId("one"); + } + }