From 6e41246a2bb01de4380d45d15cc991de1f9fe2c9 Mon Sep 17 00:00:00 2001 From: Han YanJing Date: Sat, 27 Feb 2021 16:37:33 +0800 Subject: [PATCH] Throw Saml2AuthenticationException Closes gh-9310 --- .../saml2/Saml2LoginConfigurerTests.java | 41 ++++++++++++++++++- .../Saml2AuthenticationTokenConverter.java | 19 ++++++--- ...aml2AuthenticationTokenConverterTests.java | 40 +++++++++++++++++- 3 files changed, 93 insertions(+), 7 deletions(-) diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index bad65f790c..be71694088 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,12 +30,14 @@ import java.util.zip.InflaterOutputStream; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.AuthnRequest; @@ -62,10 +64,13 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.core.Saml2Utils; import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; @@ -78,6 +83,7 @@ import org.springframework.security.saml2.provider.service.servlet.filter.Saml2W import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; @@ -210,6 +216,24 @@ public class Saml2LoginConfigurerTests { verify(CustomAuthenticationConverter.authenticationConverter).convert(any(HttpServletRequest.class)); } + @Test + public void authenticateWithInvalidDeflatedSAMLResponseThenFailureHandlerUses() throws Exception { + this.spring.register(CustomAuthenticationFailureHandler.class).autowire(); + byte[] invalidDeflated = "invalid".getBytes(); + String encoded = Saml2Utils.samlEncode(invalidDeflated); + MockHttpServletRequestBuilder request = get("/login/saml2/sso/registration-id").queryParam("SAMLResponse", + encoded); + this.mvc.perform(request); + ArgumentCaptor captor = ArgumentCaptor + .forClass(Saml2AuthenticationException.class); + verify(CustomAuthenticationFailureHandler.authenticationFailureHandler).onAuthenticationFailure( + any(HttpServletRequest.class), any(HttpServletResponse.class), captor.capture()); + Saml2AuthenticationException exception = captor.getValue(); + assertThat(exception.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE); + assertThat(exception.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string"); + assertThat(exception.getCause()).isInstanceOf(IOException.class); + } + private void validateSaml2WebSsoAuthenticationFilterConfiguration() { // get the OpenSamlAuthenticationProvider Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain); @@ -314,6 +338,21 @@ public class Saml2LoginConfigurerTests { } + @EnableWebSecurity + @Import(Saml2LoginConfigBeans.class) + static class CustomAuthenticationFailureHandler extends WebSecurityConfigurerAdapter { + + static final AuthenticationFailureHandler authenticationFailureHandler = mock( + AuthenticationFailureHandler.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + http.authorizeRequests((authz) -> authz.anyRequest().authenticated()) + .saml2Login((saml2) -> saml2.failureHandler(authenticationFailureHandler)); + } + + } + @EnableWebSecurity @Import(Saml2LoginConfigBeans.class) static class CustomAuthenticationRequestContextResolver extends WebSecurityConfigurerAdapter { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java index bcce7e6fa8..9a43a880cf 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,9 @@ import org.apache.commons.codec.binary.Base64; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpMethod; -import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.web.authentication.AuthenticationConverter; @@ -83,7 +85,13 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo } private byte[] samlDecode(String s) { - return BASE64.decode(s); + try { + return BASE64.decode(s); + } + catch (Exception ex) { + throw new Saml2AuthenticationException( + new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Failed to decode SAMLResponse"), ex); + } } private String samlInflate(byte[] b) { @@ -94,8 +102,9 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo inflaterOutputStream.finish(); return new String(out.toByteArray(), StandardCharsets.UTF_8); } - catch (IOException ex) { - throw new Saml2Exception("Unable to inflate string", ex); + catch (Exception ex) { + throw new Saml2AuthenticationException( + new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Unable to inflate string"), ex); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java index 91038de6ae..4a5e9e2fea 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,9 @@ import org.mockito.junit.MockitoJUnitRunner; import org.springframework.core.convert.converter.Converter; import org.springframework.core.io.ClassPathResource; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2Utils; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; @@ -37,6 +39,7 @@ import org.springframework.util.StreamUtils; import org.springframework.web.util.UriUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; @@ -64,6 +67,22 @@ public class Saml2AuthenticationTokenConverterTests { .isEqualTo(this.relyingPartyRegistration.getRegistrationId()); } + @Test + public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() { + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( + this.relyingPartyRegistrationResolver); + given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .willReturn(this.relyingPartyRegistration); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter("SAMLResponse", "invalid"); + assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request)) + .withCauseInstanceOf(IllegalArgumentException.class) + .satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode()) + .isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE)) + .satisfies((ex) -> assertThat(ex.getSaml2Error().getDescription()) + .isEqualTo("Failed to decode SAMLResponse")); + } + @Test public void convertWhenNoSamlResponseThenNull() { Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( @@ -100,6 +119,25 @@ public class Saml2AuthenticationTokenConverterTests { .isEqualTo(this.relyingPartyRegistration.getRegistrationId()); } + @Test + public void convertWhenGetRequestInvalidDeflatedThenSaml2AuthenticationException() { + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( + this.relyingPartyRegistrationResolver); + given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .willReturn(this.relyingPartyRegistration); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("GET"); + byte[] invalidDeflated = "invalid".getBytes(); + String encoded = Saml2Utils.samlEncode(invalidDeflated); + request.setParameter("SAMLResponse", encoded); + assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request)) + .withCauseInstanceOf(IOException.class) + .satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode()) + .isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE)) + .satisfies( + (ex) -> assertThat(ex.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string")); + } + @Test public void constructorWhenResolverIsNullThenIllegalArgument() { assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null));