@ -1,5 +1,5 @@
/ *
/ *
* Copyright 2002 - 2021 the original author or authors .
* Copyright 2002 - 2022 the original author or authors .
*
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
* you may not use this file except in compliance with the License .
@ -30,6 +30,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration ;
import org.springframework.context.annotation.Configuration ;
import org.springframework.http.HttpMethod ;
import org.springframework.http.HttpMethod ;
import org.springframework.mock.web.MockHttpSession ;
import org.springframework.mock.web.MockHttpSession ;
import org.springframework.security.config.Customizer ;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder ;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder ;
import org.springframework.security.config.annotation.web.builders.HttpSecurity ;
import org.springframework.security.config.annotation.web.builders.HttpSecurity ;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity ;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity ;
@ -38,9 +39,12 @@ import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension ;
import org.springframework.security.config.test.SpringTestContextExtension ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.userdetails.PasswordEncodedUser ;
import org.springframework.security.core.userdetails.PasswordEncodedUser ;
import org.springframework.security.web.SecurityFilterChain ;
import org.springframework.security.web.access.AccessDeniedHandler ;
import org.springframework.security.web.access.AccessDeniedHandler ;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy ;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy ;
import org.springframework.security.web.csrf.CsrfToken ;
import org.springframework.security.web.csrf.CsrfTokenRepository ;
import org.springframework.security.web.csrf.CsrfTokenRepository ;
import org.springframework.security.web.csrf.CsrfTokenRequestProcessor ;
import org.springframework.security.web.csrf.DefaultCsrfToken ;
import org.springframework.security.web.csrf.DefaultCsrfToken ;
import org.springframework.security.web.firewall.StrictHttpFirewall ;
import org.springframework.security.web.firewall.StrictHttpFirewall ;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher ;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher ;
@ -55,12 +59,16 @@ import org.springframework.web.servlet.support.RequestDataValueProcessor;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType ;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType ;
import static org.hamcrest.Matchers.containsString ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.eq ;
import static org.mockito.ArgumentMatchers.isNull ;
import static org.mockito.ArgumentMatchers.isNull ;
import static org.mockito.BDDMockito.given ;
import static org.mockito.BDDMockito.given ;
import static org.mockito.Mockito.atLeastOnce ;
import static org.mockito.Mockito.atLeastOnce ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.times ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.verifyNoMoreInteractions ;
import static org.springframework.security.config.Customizer.withDefaults ;
import static org.springframework.security.config.Customizer.withDefaults ;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf ;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf ;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user ;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user ;
@ -74,6 +82,7 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post ;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post ;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put ;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put ;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request ;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request ;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content ;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl ;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl ;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status ;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status ;
@ -84,6 +93,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
* @author Eleftheria Stein
* @author Eleftheria Stein
* @author Michael Vitz
* @author Michael Vitz
* @author Sam Simmons
* @author Sam Simmons
* @author Steve Riesenberg
* /
* /
@ExtendWith ( SpringTestContextExtension . class )
@ExtendWith ( SpringTestContextExtension . class )
public class CsrfConfigurerTests {
public class CsrfConfigurerTests {
@ -407,6 +417,47 @@ public class CsrfConfigurerTests {
any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
}
}
@Test
public void getLoginWhenCsrfTokenRequestProcessorSetThenRespondsWithNormalCsrfToken ( ) throws Exception {
CsrfTokenRepository csrfTokenRepository = mock ( CsrfTokenRepository . class ) ;
CsrfToken csrfToken = new DefaultCsrfToken ( "X-CSRF-TOKEN" , "_csrf" , "token" ) ;
given ( csrfTokenRepository . generateToken ( any ( HttpServletRequest . class ) ) ) . willReturn ( csrfToken ) ;
CsrfTokenRequestProcessorConfig . REPO = csrfTokenRepository ;
CsrfTokenRequestProcessorConfig . PROCESSOR = new CsrfTokenRequestProcessor ( ) ;
this . spring . register ( CsrfTokenRequestProcessorConfig . class , BasicController . class ) . autowire ( ) ;
this . mvc . perform ( get ( "/login" ) ) . andExpect ( status ( ) . isOk ( ) )
. andExpect ( content ( ) . string ( containsString ( csrfToken . getToken ( ) ) ) ) ;
verify ( csrfTokenRepository ) . loadToken ( any ( HttpServletRequest . class ) ) ;
verify ( csrfTokenRepository ) . generateToken ( any ( HttpServletRequest . class ) ) ;
verify ( csrfTokenRepository ) . saveToken ( eq ( csrfToken ) , any ( HttpServletRequest . class ) ,
any ( HttpServletResponse . class ) ) ;
verifyNoMoreInteractions ( csrfTokenRepository ) ;
}
@Test
public void loginWhenCsrfTokenRequestProcessorSetAndNormalCsrfTokenThenSuccess ( ) throws Exception {
CsrfToken csrfToken = new DefaultCsrfToken ( "X-CSRF-TOKEN" , "_csrf" , "token" ) ;
CsrfTokenRepository csrfTokenRepository = mock ( CsrfTokenRepository . class ) ;
given ( csrfTokenRepository . loadToken ( any ( HttpServletRequest . class ) ) ) . willReturn ( csrfToken ) ;
given ( csrfTokenRepository . generateToken ( any ( HttpServletRequest . class ) ) ) . willReturn ( csrfToken ) ;
CsrfTokenRequestProcessorConfig . REPO = csrfTokenRepository ;
CsrfTokenRequestProcessorConfig . PROCESSOR = new CsrfTokenRequestProcessor ( ) ;
this . spring . register ( CsrfTokenRequestProcessorConfig . class , BasicController . class ) . autowire ( ) ;
// @formatter:off
MockHttpServletRequestBuilder loginRequest = post ( "/login" )
. header ( csrfToken . getHeaderName ( ) , csrfToken . getToken ( ) )
. param ( "username" , "user" )
. param ( "password" , "password" ) ;
// @formatter:on
this . mvc . perform ( loginRequest ) . andExpect ( redirectedUrl ( "/" ) ) ;
verify ( csrfTokenRepository , times ( 2 ) ) . loadToken ( any ( HttpServletRequest . class ) ) ;
verify ( csrfTokenRepository ) . saveToken ( isNull ( ) , any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
verify ( csrfTokenRepository ) . generateToken ( any ( HttpServletRequest . class ) ) ;
verify ( csrfTokenRepository ) . saveToken ( eq ( csrfToken ) , any ( HttpServletRequest . class ) ,
any ( HttpServletResponse . class ) ) ;
verifyNoMoreInteractions ( csrfTokenRepository ) ;
}
@Configuration
@Configuration
static class AllowHttpMethodsFirewallConfig {
static class AllowHttpMethodsFirewallConfig {
@ -748,6 +799,43 @@ public class CsrfConfigurerTests {
}
}
@Configuration
@EnableWebSecurity
static class CsrfTokenRequestProcessorConfig {
static CsrfTokenRepository REPO ;
static CsrfTokenRequestProcessor PROCESSOR ;
@Bean
SecurityFilterChain securityFilterChain ( HttpSecurity http ) throws Exception {
// @formatter:off
http
. authorizeHttpRequests ( ( authorize ) - > authorize
. anyRequest ( ) . authenticated ( )
)
. formLogin ( Customizer . withDefaults ( ) )
. csrf ( ( csrf ) - > csrf
. csrfTokenRepository ( REPO )
. csrfTokenRequestAttributeHandler ( PROCESSOR )
. csrfTokenRequestResolver ( PROCESSOR )
) ;
// @formatter:on
return http . build ( ) ;
}
@Autowired
void configure ( AuthenticationManagerBuilder auth ) throws Exception {
// @formatter:off
auth
. inMemoryAuthentication ( )
. withUser ( PasswordEncodedUser . user ( ) ) ;
// @formatter:on
}
}
@RestController
@RestController
static class BasicController {
static class BasicController {