diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java index 0e96bc0c04..d72d6772e8 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java @@ -29,6 +29,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -62,6 +63,11 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { this.saml2MetadataResolver = saml2MetadataResolver; } + public Saml2MetadataFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, + Saml2MetadataResolver saml2MetadataResolver) { + this(new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository), saml2MetadataResolver); + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws ServletException, IOException { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java index 35ad396daf..a8fc1cdaae 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java @@ -64,9 +64,7 @@ public class Saml2MetadataFilterTests { public void setup() { this.repository = mock(RelyingPartyRegistrationRepository.class); this.resolver = mock(Saml2MetadataResolver.class); - RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver( - this.repository); - this.filter = new Saml2MetadataFilter(relyingPartyRegistrationResolver, this.resolver); + this.filter = new Saml2MetadataFilter(this.repository, this.resolver); this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); this.chain = mock(FilterChain.class); @@ -152,6 +150,19 @@ public class Saml2MetadataFilterTests { verify(this.repository).findByRegistrationId("registration-id"); } + @Test + public void doFilterWhenPathStartsWithOneThenServesMetadata() throws Exception { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); + given(this.repository.findByRegistrationId("one")).willReturn(registration); + given(this.resolver.resolve(any())).willReturn("metadata"); + this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("one"), + this.resolver); + this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata")); + this.request.setPathInfo("/metadata"); + this.filter.doFilter(this.request, this.response, new MockFilterChain()); + verify(this.repository).findByRegistrationId("one"); + } + // gh-12026 @Test public void doFilterWhenCharacterEncodingThenEncodeSpecialCharactersCorrectly() throws Exception {