@ -23,6 +23,7 @@ import java.util.Base64;
@@ -23,6 +23,7 @@ import java.util.Base64;
import java.util.Collection ;
import java.util.Collections ;
import javax.servlet.ServletException ;
import javax.servlet.http.HttpServletRequest ;
import org.junit.After ;
import org.junit.Assert ;
@ -55,9 +56,13 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
@@ -55,9 +56,13 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper ;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider ;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication ;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext ;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken ;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration ;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository ;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter ;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter ;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver ;
import org.springframework.security.web.FilterChainProxy ;
import org.springframework.security.web.context.HttpRequestResponseHolder ;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository ;
@ -66,10 +71,15 @@ import org.springframework.test.util.ReflectionTestUtils;
@@ -66,10 +71,15 @@ import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.anyString ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.when ;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext ;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration ;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get ;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status ;
/ * *
* Tests for different Java configuration for { @link Saml2LoginConfigurer }
@ -133,6 +143,20 @@ public class Saml2LoginConfigurerTests {
@@ -133,6 +143,20 @@ public class Saml2LoginConfigurerTests {
validateSaml2WebSsoAuthenticationFilterConfiguration ( ) ;
}
@Test
public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses ( ) throws Exception {
this . spring . register ( CustomAuthenticationRequestContextResolver . class ) . autowire ( ) ;
Saml2AuthenticationRequestContext context = authenticationRequestContext ( ) . build ( ) ;
Saml2AuthenticationRequestContextResolver resolver =
CustomAuthenticationRequestContextResolver . resolver ;
when ( resolver . resolve ( any ( HttpServletRequest . class ) , any ( RelyingPartyRegistration . class ) ) )
. thenReturn ( context ) ;
this . mvc . perform ( get ( "/saml2/authenticate/registration-id" ) )
. andExpect ( status ( ) . isFound ( ) ) ;
verify ( resolver ) . resolve ( any ( HttpServletRequest . class ) , any ( RelyingPartyRegistration . class ) ) ;
}
private void validateSaml2WebSsoAuthenticationFilterConfiguration ( ) {
// get the OpenSamlAuthenticationProvider
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter ( this . springSecurityFilterChain ) ;
@ -219,6 +243,38 @@ public class Saml2LoginConfigurerTests {
@@ -219,6 +243,38 @@ public class Saml2LoginConfigurerTests {
}
}
@EnableWebSecurity
@Import ( Saml2LoginConfigBeans . class )
static class CustomAuthenticationRequestContextResolver extends WebSecurityConfigurerAdapter {
private static final Saml2AuthenticationRequestContextResolver resolver =
mock ( Saml2AuthenticationRequestContextResolver . class ) ;
@Override
protected void configure ( HttpSecurity http ) throws Exception {
ObjectPostProcessor < Saml2WebSsoAuthenticationRequestFilter > processor
= new ObjectPostProcessor < Saml2WebSsoAuthenticationRequestFilter > ( ) {
@Override
public < O extends Saml2WebSsoAuthenticationRequestFilter > O postProcess ( O filter ) {
filter . setAuthenticationRequestContextResolver ( resolver ) ;
return filter ;
}
} ;
http
. authorizeRequests ( authz - > authz
. anyRequest ( ) . authenticated ( )
)
. saml2Login ( saml2 - > saml2
. addObjectPostProcessor ( processor )
) ;
}
@Bean
Saml2AuthenticationRequestContextResolver resolver ( ) {
return resolver ;
}
}
private static AuthenticationManager getAuthenticationManagerMock ( String role ) {
return new AuthenticationManager ( ) {