@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2019 the original author or authors .
* Copyright 2002 - 2020 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -16,12 +16,16 @@
@@ -16,12 +16,16 @@
package org.springframework.security.config.annotation.web.configurers.saml2 ;
import java.io.ByteArrayOutputStream ;
import java.io.IOException ;
import java.net.URLDecoder ;
import java.time.Duration ;
import java.util.Arrays ;
import java.util.Base64 ;
import java.util.Collection ;
import java.util.Collections ;
import java.util.zip.Inflater ;
import java.util.zip.InflaterOutputStream ;
import javax.servlet.ServletException ;
import javax.servlet.http.HttpServletRequest ;
@ -54,9 +58,12 @@ import org.springframework.security.core.AuthenticationException;
@@ -54,9 +58,12 @@ import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority ;
import org.springframework.security.core.authority.SimpleGrantedAuthority ;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper ;
import org.springframework.security.saml2.Saml2Exception ;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider ;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory ;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication ;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext ;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory ;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken ;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration ;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository ;
@ -69,7 +76,11 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi
@@ -69,7 +76,11 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi
import org.springframework.security.web.context.SecurityContextRepository ;
import org.springframework.test.util.ReflectionTestUtils ;
import org.springframework.test.web.servlet.MockMvc ;
import org.springframework.test.web.servlet.MvcResult ;
import org.springframework.web.util.UriComponents ;
import org.springframework.web.util.UriComponentsBuilder ;
import static java.nio.charset.StandardCharsets.UTF_8 ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.anyString ;
@ -157,6 +168,20 @@ public class Saml2LoginConfigurerTests {
@@ -157,6 +168,20 @@ public class Saml2LoginConfigurerTests {
verify ( resolver ) . resolve ( any ( HttpServletRequest . class ) , any ( RelyingPartyRegistration . class ) ) ;
}
@Test
public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses ( ) throws Exception {
this . spring . register ( CustomAuthnRequestConsumerResolver . class ) . autowire ( ) ;
MvcResult result = this . mvc . perform ( get ( "/saml2/authenticate/registration-id" ) )
. andReturn ( ) ;
UriComponents components = UriComponentsBuilder
. fromHttpUrl ( result . getResponse ( ) . getRedirectedUrl ( ) ) . build ( ) ;
String samlRequest = components . getQueryParams ( ) . getFirst ( "SAMLRequest" ) ;
String decoded = URLDecoder . decode ( samlRequest , "UTF-8" ) ;
String inflated = samlInflate ( samlDecode ( decoded ) ) ;
assertThat ( inflated ) . contains ( "ForceAuthn=\"true\"" ) ;
}
private void validateSaml2WebSsoAuthenticationFilterConfiguration ( ) {
// get the OpenSamlAuthenticationProvider
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter ( this . springSecurityFilterChain ) ;
@ -275,6 +300,29 @@ public class Saml2LoginConfigurerTests {
@@ -275,6 +300,29 @@ public class Saml2LoginConfigurerTests {
}
}
@EnableWebSecurity
@Import ( Saml2LoginConfigBeans . class )
static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter {
@Override
protected void configure ( HttpSecurity http ) throws Exception {
http
. authorizeRequests ( authz - > authz
. anyRequest ( ) . authenticated ( )
)
. saml2Login ( saml2 - > { } ) ;
}
@Bean
Saml2AuthenticationRequestFactory authenticationRequestFactory ( ) {
OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
new OpenSamlAuthenticationRequestFactory ( ) ;
authenticationRequestFactory . setAuthnRequestConsumerResolver (
context - > authnRequest - > authnRequest . setForceAuthn ( true ) ) ;
return authenticationRequestFactory ;
}
}
private static AuthenticationManager getAuthenticationManagerMock ( String role ) {
return new AuthenticationManager ( ) {
@ -315,4 +363,23 @@ public class Saml2LoginConfigurerTests {
@@ -315,4 +363,23 @@ public class Saml2LoginConfigurerTests {
}
}
private static org . apache . commons . codec . binary . Base64 BASE64 =
new org . apache . commons . codec . binary . Base64 ( 0 , new byte [ ] { '\n' } ) ;
private static byte [ ] samlDecode ( String s ) {
return BASE64 . decode ( s ) ;
}
private static String samlInflate ( byte [ ] b ) {
try {
ByteArrayOutputStream out = new ByteArrayOutputStream ( ) ;
InflaterOutputStream iout = new InflaterOutputStream ( out , new Inflater ( true ) ) ;
iout . write ( b ) ;
iout . finish ( ) ;
return new String ( out . toByteArray ( ) , UTF_8 ) ;
}
catch ( IOException e ) {
throw new Saml2Exception ( "Unable to inflate string" , e ) ;
}
}
}