2 changed files with 209 additions and 0 deletions
@ -0,0 +1,135 @@ |
|||||||
|
/* |
||||||
|
* Copyright 2002-2020 the original author or authors. |
||||||
|
* |
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
* you may not use this file except in compliance with the License. |
||||||
|
* You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
|
||||||
|
package org.springframework.security.saml2.provider.service.web; |
||||||
|
|
||||||
|
import java.util.HashMap; |
||||||
|
import java.util.Map; |
||||||
|
import java.util.function.Function; |
||||||
|
import javax.servlet.http.HttpServletRequest; |
||||||
|
|
||||||
|
import org.springframework.core.convert.converter.Converter; |
||||||
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; |
||||||
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; |
||||||
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; |
||||||
|
import org.springframework.security.web.util.matcher.RequestMatcher; |
||||||
|
import org.springframework.util.Assert; |
||||||
|
import org.springframework.util.StringUtils; |
||||||
|
import org.springframework.web.util.UriComponents; |
||||||
|
import org.springframework.web.util.UriComponentsBuilder; |
||||||
|
|
||||||
|
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; |
||||||
|
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl; |
||||||
|
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl; |
||||||
|
|
||||||
|
/** |
||||||
|
* A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the |
||||||
|
* registration id from the request, querying a {@link RelyingPartyRegistrationRepository}, |
||||||
|
* and resolving any template values. |
||||||
|
* |
||||||
|
* @since 5.4 |
||||||
|
* @author Josh Cummings |
||||||
|
*/ |
||||||
|
public final class DefaultRelyingPartyRegistrationResolver |
||||||
|
implements Converter<HttpServletRequest, RelyingPartyRegistration> { |
||||||
|
|
||||||
|
private static final char PATH_DELIMITER = '/'; |
||||||
|
|
||||||
|
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; |
||||||
|
private final Converter<HttpServletRequest, String> registrationIdResolver = new RegistrationIdResolver(); |
||||||
|
|
||||||
|
public DefaultRelyingPartyRegistrationResolver |
||||||
|
(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { |
||||||
|
|
||||||
|
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null"); |
||||||
|
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public RelyingPartyRegistration convert(HttpServletRequest request) { |
||||||
|
String registrationId = this.registrationIdResolver.convert(request); |
||||||
|
if (registrationId == null) { |
||||||
|
return null; |
||||||
|
} |
||||||
|
RelyingPartyRegistration relyingPartyRegistration = |
||||||
|
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId); |
||||||
|
if (relyingPartyRegistration == null) { |
||||||
|
return null; |
||||||
|
} |
||||||
|
|
||||||
|
String applicationUri = getApplicationUri(request); |
||||||
|
Function<String, String> templateResolver = templateResolver(applicationUri, relyingPartyRegistration); |
||||||
|
String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId()); |
||||||
|
String assertionConsumerServiceLocation = templateResolver.apply( |
||||||
|
relyingPartyRegistration.getAssertionConsumerServiceLocation()); |
||||||
|
return withRelyingPartyRegistration(relyingPartyRegistration) |
||||||
|
.entityId(relyingPartyEntityId) |
||||||
|
.assertionConsumerServiceLocation(assertionConsumerServiceLocation) |
||||||
|
.build(); |
||||||
|
} |
||||||
|
|
||||||
|
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) { |
||||||
|
return template -> resolveUrlTemplate(template, applicationUri, relyingParty); |
||||||
|
} |
||||||
|
|
||||||
|
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { |
||||||
|
String entityId = relyingParty.getAssertingPartyDetails().getEntityId(); |
||||||
|
String registrationId = relyingParty.getRegistrationId(); |
||||||
|
Map<String, String> uriVariables = new HashMap<>(); |
||||||
|
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl) |
||||||
|
.replaceQuery(null) |
||||||
|
.fragment(null) |
||||||
|
.build(); |
||||||
|
String scheme = uriComponents.getScheme(); |
||||||
|
uriVariables.put("baseScheme", scheme == null ? "" : scheme); |
||||||
|
String host = uriComponents.getHost(); |
||||||
|
uriVariables.put("baseHost", host == null ? "" : host); |
||||||
|
// following logic is based on HierarchicalUriComponents#toUriString()
|
||||||
|
int port = uriComponents.getPort(); |
||||||
|
uriVariables.put("basePort", port == -1 ? "" : ":" + port); |
||||||
|
String path = uriComponents.getPath(); |
||||||
|
if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) { |
||||||
|
path = PATH_DELIMITER + path; |
||||||
|
} |
||||||
|
uriVariables.put("basePath", path == null ? "" : path); |
||||||
|
uriVariables.put("baseUrl", uriComponents.toUriString()); |
||||||
|
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); |
||||||
|
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); |
||||||
|
|
||||||
|
return UriComponentsBuilder.fromUriString(template) |
||||||
|
.buildAndExpand(uriVariables) |
||||||
|
.toUriString(); |
||||||
|
} |
||||||
|
|
||||||
|
private static String getApplicationUri(HttpServletRequest request) { |
||||||
|
UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request)) |
||||||
|
.replacePath(request.getContextPath()) |
||||||
|
.replaceQuery(null) |
||||||
|
.fragment(null) |
||||||
|
.build(); |
||||||
|
return uriComponents.toUriString(); |
||||||
|
} |
||||||
|
|
||||||
|
private static class RegistrationIdResolver implements Converter<HttpServletRequest, String> { |
||||||
|
private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}"); |
||||||
|
|
||||||
|
@Override |
||||||
|
public String convert(HttpServletRequest request) { |
||||||
|
RequestMatcher.MatchResult result = this.requestMatcher.matcher(request); |
||||||
|
return result.getVariables().get("registrationId"); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
@ -0,0 +1,74 @@ |
|||||||
|
/* |
||||||
|
* Copyright 2002-2020 the original author or authors. |
||||||
|
* |
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
* you may not use this file except in compliance with the License. |
||||||
|
* You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
|
||||||
|
package org.springframework.security.saml2.provider.service.web; |
||||||
|
|
||||||
|
import org.junit.Test; |
||||||
|
|
||||||
|
import org.springframework.mock.web.MockHttpServletRequest; |
||||||
|
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; |
||||||
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; |
||||||
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; |
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat; |
||||||
|
import static org.assertj.core.api.Assertions.assertThatCode; |
||||||
|
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; |
||||||
|
|
||||||
|
/** |
||||||
|
* Tests for {@link DefaultRelyingPartyRegistrationResolver} |
||||||
|
*/ |
||||||
|
public class DefaultRelyingPartyRegistrationResolverTests { |
||||||
|
private final RelyingPartyRegistration registration = relyingPartyRegistration().build(); |
||||||
|
private final RelyingPartyRegistrationRepository repository = |
||||||
|
new InMemoryRelyingPartyRegistrationRepository(this.registration); |
||||||
|
private final DefaultRelyingPartyRegistrationResolver resolver = |
||||||
|
new DefaultRelyingPartyRegistrationResolver(this.repository); |
||||||
|
|
||||||
|
@Test |
||||||
|
public void resolveWhenRequestContainsRegistrationIdThenResolves() { |
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest(); |
||||||
|
request.setPathInfo("/some/path/" + this.registration.getRegistrationId()); |
||||||
|
RelyingPartyRegistration registration = this.resolver.convert(request); |
||||||
|
assertThat(registration).isNotNull(); |
||||||
|
assertThat(registration.getRegistrationId()) |
||||||
|
.isEqualTo(this.registration.getRegistrationId()); |
||||||
|
assertThat(registration.getEntityId()) |
||||||
|
.isEqualTo("http://localhost/saml2/service-provider-metadata/" + this.registration.getRegistrationId()); |
||||||
|
assertThat(registration.getAssertionConsumerServiceLocation()) |
||||||
|
.isEqualTo("http://localhost/login/saml2/sso/" + this.registration.getRegistrationId()); |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void resolveWhenRequestContainsInvalidRegistrationIdThenNull() { |
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest(); |
||||||
|
request.setPathInfo("/some/path/not-" + this.registration.getRegistrationId()); |
||||||
|
RelyingPartyRegistration registration = this.resolver.convert(request); |
||||||
|
assertThat(registration).isNull(); |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void resolveWhenRequestIsMissingRegistrationIdThenNull() { |
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest(); |
||||||
|
RelyingPartyRegistration registration = this.resolver.convert(request); |
||||||
|
assertThat(registration).isNull(); |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void constructorWhenNullRelyingPartyRegistrationThenIllegalArgument() { |
||||||
|
assertThatCode(() -> new DefaultRelyingPartyRegistrationResolver(null)) |
||||||
|
.isInstanceOf(IllegalArgumentException.class); |
||||||
|
} |
||||||
|
} |
||||||
Loading…
Reference in new issue