diff --git a/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc b/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc index 0bfe7b8a14..bef6ae4759 100644 --- a/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc +++ b/docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc @@ -1271,8 +1271,29 @@ It's not required to call `OpenSaml4AuthenticationProvider` 's default authentic It returns a `Saml2AuthenticatedPrincipal` containing the attributes it extracted from `AttributeStatement` s as well as the single `ROLE_USER` authority. [[servlet-saml2login-opensamlauthenticationprovider-additionalvalidation]] -==== Performing Additional Validation +==== Performing Additional Response Validation +`OpenSaml4AuthenticationProvider` validates the `Issuer` and `Destination` values right after decrypting the `Response`. +You can customize the validation by extending the default validator concatenating with your own response validator, or you can replace it entirely with yours. + +For example, you can throw a custom exception with any additional information available in the `Response` object, like so: +[source,java] +---- +OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider(); +provider.setResponseValidator((responseToken) -> { + Saml2ResponseValidatorResult result = OpenSamlAuthenticationProvider + .createDefaultResponseValidator() + .convert(responseToken) + .concat(myCustomValidator.convert(responseToken)); + if (!result.getErrors().isEmpty()) { + String inResponseTo = responseToken.getInResponseTo(); + throw new CustomSaml2AuthenticationException(result, inResponseTo); + } + return result; +}); +---- + +==== Performing Additional Assertion Validation `OpenSaml4AuthenticationProvider` performs minimal validation on SAML 2.0 Assertions. After verifying the signature, it will: diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index 08f53a82d1..74f5cde0d6 100644 --- a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -145,7 +145,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv private Consumer responseElementsDecrypter = createDefaultResponseElementsDecrypter(); - private final Converter responseValidator = createDefaultResponseValidator(); + private Converter responseValidator = createDefaultResponseValidator(); private final Converter assertionSignatureValidator = createDefaultAssertionSignatureValidator(); @@ -213,6 +213,28 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv this.responseElementsDecrypter = responseElementsDecrypter; } + /** + * Set the {@link Converter} to use for validating the SAML 2.0 Response. + * + * You can still invoke the default validator by delegating to + * {@link #createDefaultResponseValidator()}, like so: + * + *
+	 * OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider();
+	 * provider.setResponseValidator(responseToken -> {
+	 * 		Saml2ResponseValidatorResult result = createDefaultResponseValidator()
+	 * 			.convert(responseToken)
+	 * 		return result.concat(myCustomValidator.convert(responseToken));
+	 * });
+	 * 
+ * @param responseValidator the {@link Converter} to use + * @since 5.6 + */ + public void setResponseValidator(Converter responseValidator) { + Assert.notNull(responseValidator, "responseValidator cannot be null"); + this.responseValidator = responseValidator; + } + /** * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML * 2.0 Response. @@ -326,6 +348,44 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv this.responseAuthenticationConverter = responseAuthenticationConverter; } + /** + * Construct a default strategy for validating the SAML 2.0 Response + * @return the default response validator strategy + * @since 5.6 + */ + public static Converter createDefaultResponseValidator() { + return (responseToken) -> { + Response response = responseToken.getResponse(); + Saml2AuthenticationToken token = responseToken.getToken(); + Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success(); + String statusCode = getStatusCode(response); + if (!StatusCode.SUCCESS.equals(statusCode)) { + String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode, + response.getID()); + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); + } + String issuer = response.getIssuer().getValue(); + String destination = response.getDestination(); + String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); + if (StringUtils.hasText(destination) && !destination.equals(location)) { + String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + + "]"; + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message)); + } + String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails() + .getEntityId(); + if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) { + String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID()); + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message)); + } + if (response.getAssertions().isEmpty()) { + throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, + "No assertions found in response.", null); + } + return result; + }; + } + /** * Construct a default strategy for validating each SAML 2.0 Assertion and associated * {@link Authentication} token @@ -487,40 +547,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv }; } - private Converter createDefaultResponseValidator() { - return (responseToken) -> { - Response response = responseToken.getResponse(); - Saml2AuthenticationToken token = responseToken.getToken(); - Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success(); - String statusCode = getStatusCode(response); - if (!StatusCode.SUCCESS.equals(statusCode)) { - String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode, - response.getID()); - result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); - } - String issuer = response.getIssuer().getValue(); - String destination = response.getDestination(); - String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); - if (StringUtils.hasText(destination) && !destination.equals(location)) { - String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() - + "]"; - result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message)); - } - String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails() - .getEntityId(); - if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) { - String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID()); - result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message)); - } - if (response.getAssertions().isEmpty()) { - throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, - "No assertions found in response.", null); - } - return result; - }; - } - - private String getStatusCode(Response response) { + private static String getStatusCode(Response response) { if (response.getStatus() == null) { return StatusCode.SUCCESS; } diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index f57c645e4a..456c6d02ee 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -585,6 +585,34 @@ public class OpenSaml4AuthenticationProviderTests { assertThat(authentication.getName()).isEqualTo("test@saml.user"); } + @Test + public void setResponseValidatorWhenNullThenIllegalArgument() { + assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setResponseValidator(null)); + } + + @Test + public void authenticateWhenCustomResponseValidatorThenUses() { + Converter validator = mock( + Converter.class); + OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider(); + // @formatter:off + provider.setResponseValidator((responseToken) -> OpenSaml4AuthenticationProvider.createDefaultResponseValidator() + .convert(responseToken) + .concat(validator.convert(responseToken)) + ); + // @formatter:on + Response response = response(); + Assertion assertion = assertion(); + response.getAssertions().add(assertion); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + ASSERTING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, verifying(registration())); + given(validator.convert(any(OpenSaml4AuthenticationProvider.ResponseToken.class))) + .willReturn(Saml2ResponseValidatorResult.success()); + provider.authenticate(token); + verify(validator).convert(any(OpenSaml4AuthenticationProvider.ResponseToken.class)); + } + private T build(QName qName) { return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName); }