diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index f23c644fbb..11fa522588 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -28,6 +28,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Consumer; import javax.annotation.Nonnull; import javax.xml.namespace.QName; @@ -196,10 +197,23 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion))); }; - private Converter assertionValidator = assertionToken -> { - ValidationContext context = createValidationContext(assertionToken); - return createDefaultAssertionValidator(context).convert(assertionToken); - }; + private Converter assertionSignatureValidator = + createDefaultAssertionValidator(INVALID_SIGNATURE, + assertionToken -> { + SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token); + return SAML20AssertionValidators.createSignatureValidator(engine); + }, + assertionToken -> + new ValidationContext(Collections.singletonMap(SIGNATURE_REQUIRED, false)) + ); + + private Converter assertionValidator = + createDefaultAssertionValidator(INVALID_ASSERTION, + assertionToken -> SAML20AssertionValidators.attributeValidator, + assertionToken -> createValidationContext( + assertionToken, + params -> params.put(CLOCK_SKEW, this.responseTimeValidationSkew.toMillis()) + )); private Converter signatureTrustEngineConverter = new SignatureTrustEngineConverter(); @@ -220,34 +234,40 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML 2.0 Response. * * You can still invoke the default validator by delgating to - * {@link #createDefaultAssertionValidator(ValidationContext)}, like so: + * {@link #createDefaultAssertionValidator}, like so: * *
 	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
 	 *  provider.setAssertionValidator(assertionToken -> {
-	 *		ValidationContext context = // ... build using authentication token
-	 *		Saml2ResponseValidatorResult result = createDefaultAssertionValidator(context)
+	 *		Saml2ResponseValidatorResult result = createDefaultAssertionValidator()
 	 *			.convert(assertionToken)
-	 *		return result.concat(myCustomValiator.convert(assertionToken));
+	 *		return result.concat(myCustomValidator.convert(assertionToken));
 	 *  });
 	 * 
* - * Consider taking a look at {@link #createValidationContext(AssertionToken)} to see how it - * constructs a {@link ValidationContext}. - * * You can also use this method to configure the provider to use a different * {@link ValidationContext} from the default, like so: * *
 	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
-	 *	ValidationContext context = // ...
-	 *	provider.setAssertionValidator(createDefaultAssertionValidator(context));
+	 *	provider.setAssertionValidator(
+	 *		createDefaultAssertionValidator(assertionToken -> {
+	 *			Map<String, Object> params = new HashMap<>();
+	 *			params.put(CLOCK_SKEW, 2 * 60 * 1000);
+	 *			// other parameters
+	 *			return new ValidationContext(params);
+	 *		}));
 	 * 
* + * Consider taking a look at {@link #createValidationContext} to see how it + * constructs a {@link ValidationContext}. + * * It is not necessary to delegate to the default validator. You can safely replace it * entirely with your own. Note that signature verification is performed as a separate * step from this validator. * + * This method takes precedence over {@link #setResponseTimeValidationSkew}. + * * @param assertionValidator * @since 5.4 */ @@ -314,11 +334,45 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi * Sets the duration for how much time skew an assertion may tolerate during * timestamp, NotOnOrBefore and NotOnOrAfter, validation. * @param responseTimeValidationSkew duration for skew tolerance + * @deprecated Use {@link #setAssertionValidator(Converter)} instead */ public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) { this.responseTimeValidationSkew = responseTimeValidationSkew; } + + /** + * Construct a default strategy for validating each SAML 2.0 Assertion and + * associated {@link Authentication} token + * + * @return the default assertion validator strategy + * @since 5.4 + */ + public static Converter + createDefaultAssertionValidator() { + + return createDefaultAssertionValidator(INVALID_ASSERTION, + assertionToken -> SAML20AssertionValidators.attributeValidator, + assertionToken -> createValidationContext(assertionToken, params -> {})); + } + + /** + * Construct a default strategy for validating each SAML 2.0 Assertion and + * associated {@link Authentication} token + * + * @return the default assertion validator strategy + * @param contextConverter the conversion strategy to use to generate a {@link ValidationContext} + * for each assertion being validated + * @since 5.4 + */ + public static Converter + createDefaultAssertionValidator(Converter contextConverter) { + + return createDefaultAssertionValidator(INVALID_ASSERTION, + assertionToken -> SAML20AssertionValidators.attributeValidator, + contextConverter); + } + /** * Construct a default strategy for converting a SAML 2.0 Response and {@link Authentication} * token into a {@link Saml2Authentication} @@ -501,19 +555,13 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi logger.debug("Validating " + assertions.size() + " assertions"); } - ValidationContext signatureContext = new ValidationContext - (Collections.singletonMap(SIGNATURE_REQUIRED, false)); // check already performed - SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(token); - Converter signatureValidator = - createDefaultAssertionValidator(INVALID_SIGNATURE, - SAML20AssertionValidators.createSignatureValidator(engine), signatureContext); for (Assertion assertion : assertions) { if (logger.isTraceEnabled()) { logger.trace("Validating assertion " + assertion.getID()); } AssertionToken assertionToken = new AssertionToken(assertion, token); result = result - .concat(signatureValidator.convert(assertionToken)) + .concat(this.assertionSignatureValidator.convert(assertionToken)) .concat(this.assertionValidator.convert(assertionToken)); } @@ -613,18 +661,15 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } } - public static Converter - createDefaultAssertionValidator(ValidationContext context) { - - return createDefaultAssertionValidator(INVALID_ASSERTION, - SAML20AssertionValidators.createAttributeValidator(), context); - } - - private static Converter - createDefaultAssertionValidator(String errorCode, SAML20AssertionValidator validator, ValidationContext context) { + private static Converter createDefaultAssertionValidator( + String errorCode, + Converter validatorConverter, + Converter contextConverter) { return assertionToken -> { Assertion assertion = assertionToken.assertion; + SAML20AssertionValidator validator = validatorConverter.convert(assertionToken); + ValidationContext context = contextConverter.convert(assertionToken); try { ValidationResult result = validator.validate(assertion, context); if (result == ValidationResult.VALID) { @@ -643,13 +688,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi }; } - private ValidationContext createValidationContext(AssertionToken assertionToken) { + private static ValidationContext createValidationContext( + AssertionToken assertionToken, Consumer> paramsConsumer) { String audience = assertionToken.token.getRelyingPartyRegistration().getEntityId(); String recipient = assertionToken.token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); Map params = new HashMap<>(); - params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis()); params.put(COND_VALID_AUDIENCES, singleton(audience)); params.put(SC_VALID_RECIPIENTS, singleton(recipient)); + paramsConsumer.accept(params); return new ValidationContext(params); } @@ -687,15 +733,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi }); } - static SAML20AssertionValidator createAttributeValidator() { - return new SAML20AssertionValidator(conditions, subjects, statements, null, null) { + private static final SAML20AssertionValidator attributeValidator = + new SAML20AssertionValidator(conditions, subjects, statements, null, null) { @Nonnull @Override protected ValidationResult validateSignature(Assertion token, ValidationContext context) { return ValidationResult.VALID; } }; - } static SAML20AssertionValidator createSignatureValidator(SignatureTrustEngine engine) { return new SAML20AssertionValidator(new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), @@ -792,7 +837,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private final Saml2AuthenticationToken token; private final Assertion assertion; - private AssertionToken(Assertion assertion, Saml2AuthenticationToken token) { + AssertionToken(Assertion assertion, Saml2AuthenticationToken token) { this.token = token; this.assertion = assertion; } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index ab60156562..ecf2f785f7 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -19,7 +19,6 @@ package org.springframework.security.saml2.provider.service.authentication; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; -import java.io.StringReader; import java.time.Instant; import java.util.Arrays; import java.util.Collections; @@ -28,8 +27,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import javax.xml.namespace.QName; -import javax.xml.parsers.DocumentBuilder; -import javax.xml.parsers.DocumentBuilderFactory; import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.hamcrest.BaseMatcher; @@ -51,9 +48,7 @@ import org.opensaml.saml.saml2.core.EncryptedID; import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; -import org.w3c.dom.Document; import org.w3c.dom.Element; -import org.xml.sax.InputSource; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.Authentication; @@ -74,7 +69,6 @@ import static org.mockito.Mockito.when; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory; import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; -import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED; import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION; import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE; import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success; @@ -350,14 +344,23 @@ public class OpenSamlAuthenticationProviderTests { objectOutputStream.flush(); } + @Test + public void createDefaultAssertionValidatorWhenAssertionThenValidates() { + Response response = signedResponseWithOneAssertion(); + Assertion assertion = response.getAssertions().get(0); + OpenSamlAuthenticationProvider.AssertionToken assertionToken = + new OpenSamlAuthenticationProvider.AssertionToken(assertion, token()); + assertThat( + createDefaultAssertionValidator().convert(assertionToken) + .hasErrors()).isFalse(); + } + @Test public void authenticateWhenDelegatingToDefaultAssertionValidatorThenUses() { OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(assertionToken -> { - ValidationContext context = new ValidationContext(); - return createDefaultAssertionValidator(context).convert(assertionToken) - .concat(new Saml2Error("wrong error", "wrong error")); - }); + provider.setAssertionValidator(assertionToken -> + createDefaultAssertionValidator(token -> new ValidationContext()).convert(assertionToken) + .concat(new Saml2Error("wrong error", "wrong error"))); Response response = response(); Assertion assertion = assertion(); OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME); @@ -375,12 +378,9 @@ public class OpenSamlAuthenticationProviderTests { Converter validator = mock(Converter.class); OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(assertionToken -> { - ValidationContext context = new ValidationContext( - Collections.singletonMap(SC_VALID_RECIPIENTS, singleton(DESTINATION))); - return createDefaultAssertionValidator(context).convert(assertionToken) - .concat(validator.convert(assertionToken)); - }); + provider.setAssertionValidator(assertionToken -> + createDefaultAssertionValidator().convert(assertionToken) + .concat(validator.convert(assertionToken))); Response response = response(); Assertion assertion = assertion(); response.getAssertions().add(assertion); @@ -410,18 +410,19 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenValidationContextCustomizedThenUsers() { Map parameters = new HashMap<>(); - parameters.put(SC_VALID_RECIPIENTS, singleton(DESTINATION)); - parameters.put(SIGNATURE_REQUIRED, false); + parameters.put(SC_VALID_RECIPIENTS, singleton("blah")); ValidationContext context = mock(ValidationContext.class); when(context.getStaticParameters()).thenReturn(parameters); OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(assertionToken -> createDefaultAssertionValidator(context).convert(assertionToken)); + provider.setAssertionValidator(createDefaultAssertionValidator(assertionToken -> context)); Response response = response(); Assertion assertion = assertion(); response.getAssertions().add(assertion); signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); - provider.authenticate(token); + assertThatThrownBy(() -> provider.authenticate(token)) + .isInstanceOf(Saml2AuthenticationException.class) + .hasMessageContaining("Invalid assertion"); verify(context, atLeastOnce()).getStaticParameters(); } @@ -506,6 +507,10 @@ public class OpenSamlAuthenticationProviderTests { }; } + private Saml2AuthenticationToken token() { + return token(response(), relyingPartyVerifyingCredential()); + } + private Saml2AuthenticationToken token(Response response, Saml2X509Credential... credentials) { String payload = serialize(response); return token(payload, credentials); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index 1df38bd978..3032ae3af6 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -19,7 +19,10 @@ package org.springframework.security.saml2.provider.service.authentication; import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.UUID; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; @@ -45,6 +48,7 @@ import org.opensaml.core.xml.schema.impl.XSStringBuilder; import org.opensaml.core.xml.schema.impl.XSURIBuilder; import org.opensaml.saml.common.SAMLVersion; import org.opensaml.saml.common.SignableSAMLObject; +import org.opensaml.saml.common.assertion.ValidationContext; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; @@ -79,6 +83,7 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2X509Credential; import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; +import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential; final class TestOpenSamlObjects { @@ -371,6 +376,12 @@ final class TestOpenSamlObjects { return attributeStatements; } + static ValidationContext validationContext() { + Map params = new HashMap<>(); + params.put(SC_VALID_RECIPIENTS, Collections.singleton(DESTINATION)); + return new ValidationContext(params); + } + static T build(QName qName) { return (T) getBuilderFactory().getBuilder(qName).buildObject(qName); }