diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java index f712f76d56..4363129c2d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java @@ -36,7 +36,7 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfLogoutHandler; import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.CsrfTokenRequestResolver; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.LazyCsrfTokenRepository; @@ -91,7 +91,7 @@ public final class CsrfConfigurer> private SessionAuthenticationStrategy sessionAuthenticationStrategy; - private CsrfTokenRequestAttributeHandler requestAttributeHandler; + private CsrfTokenRequestHandler requestHandler; private CsrfTokenRequestResolver requestResolver; @@ -131,14 +131,13 @@ public final class CsrfConfigurer> } /** - * Specify a {@link CsrfTokenRequestAttributeHandler} to use for making the - * {@code CsrfToken} available as a request attribute. - * @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use + * Specify a {@link CsrfTokenRequestHandler} to use for making the {@code CsrfToken} + * available as a request attribute. + * @param requestHandler the {@link CsrfTokenRequestHandler} to use * @return the {@link CsrfConfigurer} for further customizations */ - public CsrfConfigurer csrfTokenRequestAttributeHandler( - CsrfTokenRequestAttributeHandler requestAttributeHandler) { - this.requestAttributeHandler = requestAttributeHandler; + public CsrfConfigurer csrfTokenRequestHandler(CsrfTokenRequestHandler requestHandler) { + this.requestHandler = requestHandler; return this; } @@ -247,8 +246,8 @@ public final class CsrfConfigurer> if (sessionConfigurer != null) { sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy()); } - if (this.requestAttributeHandler != null) { - filter.setRequestAttributeHandler(this.requestAttributeHandler); + if (this.requestHandler != null) { + filter.setRequestHandler(this.requestHandler); } if (this.requestResolver != null) { filter.setRequestResolver(this.requestResolver); @@ -343,8 +342,8 @@ public final class CsrfConfigurer> } CsrfAuthenticationStrategy csrfAuthenticationStrategy = new CsrfAuthenticationStrategy( this.csrfTokenRepository); - if (this.requestAttributeHandler != null) { - csrfAuthenticationStrategy.setRequestAttributeHandler(this.requestAttributeHandler); + if (this.requestHandler != null) { + csrfAuthenticationStrategy.setRequestHandler(this.requestHandler); } return csrfAuthenticationStrategy; } diff --git a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java index e8c92d9db3..49ba28cab6 100644 --- a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java @@ -70,7 +70,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { private static final String ATT_REPOSITORY = "token-repository-ref"; - private static final String ATT_REQUEST_ATTRIBUTE_HANDLER = "request-attribute-handler-ref"; + private static final String ATT_REQUEST_HANDLER = "request-handler-ref"; private static final String ATT_REQUEST_RESOLVER = "request-resolver-ref"; @@ -80,7 +80,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { private String requestMatcherRef; - private String requestAttributeHandlerRef; + private String requestHandlerRef; private String requestResolverRef; @@ -102,7 +102,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { if (element != null) { this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY); this.requestMatcherRef = element.getAttribute(ATT_MATCHER); - this.requestAttributeHandlerRef = element.getAttribute(ATT_REQUEST_ATTRIBUTE_HANDLER); + this.requestHandlerRef = element.getAttribute(ATT_REQUEST_HANDLER); this.requestResolverRef = element.getAttribute(ATT_REQUEST_RESOLVER); } if (!StringUtils.hasText(this.csrfRepositoryRef)) { @@ -119,8 +119,8 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { if (StringUtils.hasText(this.requestMatcherRef)) { builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef); } - if (StringUtils.hasText(this.requestAttributeHandlerRef)) { - builder.addPropertyReference("requestAttributeHandler", this.requestAttributeHandlerRef); + if (StringUtils.hasText(this.requestHandlerRef)) { + builder.addPropertyReference("requestHandler", this.requestHandlerRef); } if (StringUtils.hasText(this.requestResolverRef)) { builder.addPropertyReference("requestResolver", this.requestResolverRef); diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc index dcec72a232..0738bd164e 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc @@ -1152,8 +1152,8 @@ csrf-options.attlist &= ## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository. attribute token-repository-ref { xsd:token }? csrf-options.attlist &= - ## The CsrfTokenRequestAttributeHandler to use. The default is CsrfTokenRequestProcessor. - attribute request-attribute-handler-ref { xsd:token }? + ## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor. + attribute request-handler-ref { xsd:token }? csrf-options.attlist &= ## The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor. attribute request-resolver-ref { xsd:token }? diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd index dc2911daac..fbc507bdcf 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd @@ -3256,9 +3256,9 @@ - + - The CsrfTokenRequestAttributeHandler to use. The default is CsrfTokenRequestProcessor. + The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor. diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc index f64965cee0..72c51bbd50 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc +++ b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc @@ -1124,8 +1124,8 @@ csrf-options.attlist &= ## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository. attribute token-repository-ref { xsd:token }? csrf-options.attlist &= - ## The CsrfTokenRequestAttributeHandler to use. The default is CsrfTokenRequestProcessor. - attribute request-attribute-handler-ref { xsd:token }? + ## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor. + attribute request-handler-ref { xsd:token }? csrf-options.attlist &= ## The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor. attribute request-resolver-ref { xsd:token }? diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd index 704daae069..6a98bd2b5a 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd +++ b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd @@ -3166,9 +3166,9 @@ - + - The CsrfTokenRequestAttributeHandler to use. The default is CsrfTokenRequestProcessor. + The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor. diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java index e0cdbab6c2..92ad4b5e2c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java @@ -32,8 +32,8 @@ import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.web.DefaultSecurityFilterChain; import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; -import org.springframework.security.web.csrf.LazyCsrfTokenRepository; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.never; @@ -78,8 +78,7 @@ public class DeferHttpSessionJavaConfigTests { @Bean DefaultSecurityFilterChain springSecurity(HttpSecurity http) throws Exception { - LazyCsrfTokenRepository csrfRepository = new LazyCsrfTokenRepository(new HttpSessionCsrfTokenRepository()); - csrfRepository.setDeferLoadToken(true); + CsrfTokenRepository csrfRepository = new HttpSessionCsrfTokenRepository(); // @formatter:off http .authorizeHttpRequests((requests) -> requests diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java index ecdc1706b0..bd0a93b4ae 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java @@ -67,7 +67,6 @@ import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.springframework.security.config.Customizer.withDefaults; @@ -234,8 +233,6 @@ public class CsrfConfigurerTests { this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf()) .session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound()) .andExpect(redirectedUrl(redirectUrl)); - verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce()) - .loadToken(any(HttpServletRequest.class)); } // SEC-2422 @@ -280,12 +277,12 @@ public class CsrfConfigurerTests { } @Test - public void getWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception { + public void postWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception { CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); given(CsrfTokenRepositoryConfig.REPO.loadToken(any())) .willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")); this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); - this.mvc.perform(get("/")).andExpect(status().isOk()); + this.mvc.perform(post("/")); verify(CsrfTokenRepositoryConfig.REPO).loadToken(any(HttpServletRequest.class)); } @@ -322,7 +319,7 @@ public class CsrfConfigurerTests { given(CsrfTokenRepositoryInLambdaConfig.REPO.loadToken(any())) .willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")); this.spring.register(CsrfTokenRepositoryInLambdaConfig.class, BasicController.class).autowire(); - this.mvc.perform(get("/")).andExpect(status().isOk()); + this.mvc.perform(post("/")); verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadToken(any(HttpServletRequest.class)); } @@ -427,8 +424,8 @@ public class CsrfConfigurerTests { CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken); - CsrfTokenRequestProcessorConfig.REPO = csrfTokenRepository; CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor(); + CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository); this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire(); this.mvc.perform(get("/login")).andExpect(status().isOk()) .andExpect(content().string(containsString(csrfToken.getToken()))); @@ -443,10 +440,11 @@ public class CsrfConfigurerTests { public void loginWhenCsrfTokenRequestProcessorSetAndNormalCsrfTokenThenSuccess() throws Exception { CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); - given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(csrfToken); + given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken); given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken); - CsrfTokenRequestProcessorConfig.REPO = csrfTokenRepository; CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor(); + CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository); + this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire(); // @formatter:off MockHttpServletRequestBuilder loginRequest = post("/login") @@ -455,8 +453,7 @@ public class CsrfConfigurerTests { .param("password", "password"); // @formatter:on this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); - verify(csrfTokenRepository, times(2)).loadToken(any(HttpServletRequest.class)); - verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(csrfTokenRepository).loadToken(any(HttpServletRequest.class)); verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class)); verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class), any(HttpServletResponse.class)); @@ -826,8 +823,6 @@ public class CsrfConfigurerTests { @EnableWebSecurity static class CsrfTokenRequestProcessorConfig { - static CsrfTokenRepository REPO; - static CsrfTokenRequestProcessor PROCESSOR; @Bean @@ -839,8 +834,7 @@ public class CsrfConfigurerTests { ) .formLogin(Customizer.withDefaults()) .csrf((csrf) -> csrf - .csrfTokenRepository(REPO) - .csrfTokenRequestAttributeHandler(PROCESSOR) + .csrfTokenRequestHandler(PROCESSOR) .csrfTokenRequestResolver(PROCESSOR) ); // @formatter:on diff --git a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java index 9b5579d87c..d017e57a3a 100644 --- a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java @@ -29,6 +29,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpMethod; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.config.test.SpringTestContext; @@ -41,6 +42,7 @@ import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.stereotype.Controller; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -544,8 +546,9 @@ public class CsrfConfigTests { @Override public void match(MvcResult result) { MockHttpServletRequest request = result.getRequest(); - CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request); - assertThat(token).isNotNull(); + MockHttpServletResponse response = result.getResponse(); + DeferredCsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response); + assertThat(token.isGenerated()).isFalse(); } } @@ -561,7 +564,8 @@ public class CsrfConfigTests { @Override public void match(MvcResult result) throws Exception { MockHttpServletRequest request = result.getRequest(); - CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request); + MockHttpServletResponse response = result.getResponse(); + CsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response).get(); assertThat(token).isNotNull(); assertThat(token.getToken()).isEqualTo(this.token.apply(result)); } diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/CsrfDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/CsrfDslTests.kt index 7a54a1b98e..1db8493faf 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/CsrfDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/CsrfDslTests.kt @@ -41,7 +41,6 @@ import org.springframework.security.web.csrf.DefaultCsrfToken import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository import org.springframework.security.web.util.matcher.AntPathRequestMatcher import org.springframework.test.web.servlet.MockMvc -import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.post import org.springframework.web.bind.annotation.PostMapping import org.springframework.web.bind.annotation.RestController @@ -125,9 +124,9 @@ class CsrfDslTests { CustomRepositoryConfig.REPO.loadToken(any()) } returns DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token") - this.mockMvc.get("/test1") + this.mockMvc.post("/test1") - verify(exactly = 1) { CustomRepositoryConfig.REPO.loadToken(any()) } + verify(exactly = 1) { CustomRepositoryConfig.REPO.loadToken(any()) } } @Configuration diff --git a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml index 541f66453f..37950840c8 100644 --- a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml +++ b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml @@ -23,10 +23,10 @@ http://www.springframework.org/schema/beans https://www.springframework.org/schema/beans/spring-beans.xsd"> - + - diff --git a/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml b/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml index 2e7552dfa9..2a13f5fa4d 100644 --- a/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml +++ b/config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml @@ -40,5 +40,7 @@ + diff --git a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc index 50cd7bdab0..bb146d8e47 100644 --- a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc +++ b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc @@ -774,9 +774,9 @@ It is highly recommended to leave CSRF protection enabled. The CsrfTokenRepository to use. The default is `HttpSessionCsrfTokenRepository`. -[[nsa-csrf-request-attribute-handler-ref]] -* **request-attribute-handler-ref** -The optional `CsrfTokenRequestAttributeHandler` to use. The default is `CsrfTokenRequestProcessor`. +[[nsa-csrf-request-handler-ref]] +* **request-handler-ref** +The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRequestProcessor`. [[nsa-csrf-request-resolver-ref]] * **request-resolver-ref** diff --git a/etc/checkstyle/checkstyle.xml b/etc/checkstyle/checkstyle.xml index 166b84e46f..40bce12ae5 100644 --- a/etc/checkstyle/checkstyle.xml +++ b/etc/checkstyle/checkstyle.xml @@ -17,6 +17,7 @@ + diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 31dd06bd28..ac7993d5ac 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -94,7 +94,8 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfToken; -import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; @@ -508,14 +509,13 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); - if (!(repository instanceof TestCsrfTokenRepository)) { - repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository()); - WebTestUtils.setCsrfTokenRepository(request, repository); + CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request); + if (!(handler instanceof TestCsrfTokenRequestHandler)) { + handler = new TestCsrfTokenRequestHandler(handler); + WebTestUtils.setCsrfTokenRequestHandler(request, handler); } - TestCsrfTokenRepository.enable(request); - CsrfToken token = repository.generateToken(request); - repository.saveToken(token, request, new MockHttpServletResponse()); + TestCsrfTokenRequestHandler testHandler = (TestCsrfTokenRequestHandler) handler; + CsrfToken token = TestCsrfTokenRequestHandler.createTestCsrfToken(request); String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken(); if (this.asHeader) { request.addHeader(token.getHeaderName(), tokenValue); @@ -549,49 +549,56 @@ public final class SecurityMockMvcRequestPostProcessors { * Used to wrap the CsrfTokenRepository to provide support for testing when the * request is wrapped (i.e. Spring Session is in use). */ - static class TestCsrfTokenRepository implements CsrfTokenRepository { + static class TestCsrfTokenRequestHandler implements CsrfTokenRequestHandler { - static final String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".TOKEN"); + static final String TOKEN_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".TOKEN"); - static final String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".ENABLED"); + static final String ENABLED_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".ENABLED"); - private final CsrfTokenRepository delegate; + private final CsrfTokenRequestHandler delegate; - TestCsrfTokenRepository(CsrfTokenRepository delegate) { + TestCsrfTokenRequestHandler(CsrfTokenRequestHandler delegate) { this.delegate = delegate; } - @Override - public CsrfToken generateToken(HttpServletRequest request) { - return this.delegate.generateToken(request); - } - - @Override - public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { - if (isEnabled(request)) { - request.setAttribute(TOKEN_ATTR_NAME, token); - } - else { - this.delegate.saveToken(token, request, response); + static CsrfToken createTestCsrfToken(HttpServletRequest request) { + CsrfToken existingToken = getExistingToken(request); + if (existingToken != null) { + return existingToken; } + HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); + CsrfToken csrfToken = repository.generateToken(request); + request.setAttribute(ENABLED_ATTR_NAME, true); + request.setAttribute(TOKEN_ATTR_NAME, csrfToken); + return csrfToken; } - @Override - public CsrfToken loadToken(HttpServletRequest request) { - if (isEnabled(request)) { - return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME); - } - else { - return this.delegate.loadToken(request); - } + private static CsrfToken getExistingToken(HttpServletRequest request) { + Object existingToken = request.getAttribute(TOKEN_ATTR_NAME); + return (CsrfToken) existingToken; } - static void enable(HttpServletRequest request) { - request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE); + boolean isEnabled(HttpServletRequest request) { + return getExistingToken(request) != null; } - boolean isEnabled(HttpServletRequest request) { - return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME)); + @Override + public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) { + request.setAttribute(HttpServletResponse.class.getName(), response); + if (!isEnabled(request)) { + return this.delegate.handle(request, response); + } + return new DeferredCsrfToken() { + @Override + public CsrfToken get() { + return getExistingToken(request); + } + + @Override + public boolean isGenerated() { + return false; + } + }; } } diff --git a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java index 6b065f525d..ffb2c131a2 100644 --- a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java +++ b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java @@ -31,6 +31,8 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; +import org.springframework.security.web.csrf.CsrfTokenRequestProcessor; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.web.context.WebApplicationContext; @@ -46,7 +48,7 @@ public abstract class WebTestUtils { private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository(); - private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository(); + private static final CsrfTokenRequestProcessor DEFAULT_CSRF_PROCESSOR = new CsrfTokenRequestProcessor(); private WebTestUtils() { } @@ -99,24 +101,24 @@ public abstract class WebTestUtils { * @return the {@link CsrfTokenRepository} for the specified * {@link HttpServletRequest} */ - public static CsrfTokenRepository getCsrfTokenRepository(HttpServletRequest request) { + public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) { CsrfFilter filter = findFilter(request, CsrfFilter.class); if (filter == null) { - return DEFAULT_TOKEN_REPO; + return DEFAULT_CSRF_PROCESSOR; } - return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository"); + return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler"); } /** * Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}. * @param request the {@link HttpServletRequest} to obtain the * {@link CsrfTokenRepository} - * @param repository the {@link CsrfTokenRepository} to set + * @param handler the {@link CsrfTokenRepository} to set */ - public static void setCsrfTokenRepository(HttpServletRequest request, CsrfTokenRepository repository) { + public static void setCsrfTokenRequestHandler(HttpServletRequest request, CsrfTokenRequestHandler handler) { CsrfFilter filter = findFilter(request, CsrfFilter.class); if (filter != null) { - ReflectionTestUtils.setField(filter, "tokenRepository", repository); + ReflectionTestUtils.setField(filter, "requestHandler", handler); } } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java index 9dea5175bf..374aa68414 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java @@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { public void defaults() { MockHttpServletRequest request = formLogin().buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("password")).isEqualTo("password"); assertThat(request.getMethod()).isEqualTo("POST"); @@ -67,7 +67,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret") .buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); @@ -80,7 +80,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2") .user("username", "admin").password("password", "secret").buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java index df6e7cfef2..c6856fb821 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java @@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { public void defaults() { MockHttpServletRequest request = logout().buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/logout"); @@ -63,7 +63,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { public void custom() { MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/admin/logout"); @@ -74,7 +74,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2") .buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2"); diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java deleted file mode 100644 index e1ecae23e2..0000000000 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright 2002-2016 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.test.web.servlet.request; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; - -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.context.annotation.Configuration; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.builders.WebSecurity; -import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; -import org.springframework.security.test.web.support.WebTestUtils; -import org.springframework.security.web.csrf.CookieCsrfTokenRepository; -import org.springframework.security.web.csrf.CsrfTokenRepository; -import org.springframework.test.context.ContextConfiguration; -import org.springframework.test.context.junit.jupiter.SpringExtension; -import org.springframework.test.context.web.WebAppConfiguration; -import org.springframework.web.context.WebApplicationContext; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; - -@ExtendWith(SpringExtension.class) -@ContextConfiguration -@WebAppConfiguration -public class SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests { - - @Autowired - private WebApplicationContext wac; - - // SEC-3836 - @Test - public void findCookieCsrfTokenRepository() { - MockHttpServletRequest request = post("/").buildRequest(this.wac.getServletContext()); - CsrfTokenRepository csrfTokenRepository = WebTestUtils.getCsrfTokenRepository(request); - assertThat(csrfTokenRepository).isNotNull(); - assertThat(csrfTokenRepository).isEqualTo(Config.cookieCsrfTokenRepository); - } - - @Configuration - @EnableWebSecurity - static class Config extends WebSecurityConfigurerAdapter { - - static CsrfTokenRepository cookieCsrfTokenRepository = new CookieCsrfTokenRepository(); - - @Override - protected void configure(HttpSecurity http) throws Exception { - http.csrf().csrfTokenRepository(cookieCsrfTokenRepository); - } - - @Override - public void configure(WebSecurity web) { - // Enable the DebugFilter - web.debug(true); - } - - } - -} diff --git a/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java b/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java index 2ccfe3e068..38b4faa7c4 100644 --- a/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java +++ b/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java @@ -39,6 +39,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRequestProcessor; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.web.context.WebApplicationContext; @@ -74,22 +75,19 @@ public class WebTestUtilsTests { @Test public void getCsrfTokenRepositorytNoWac() { - assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) - .isInstanceOf(HttpSessionCsrfTokenRepository.class); + assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class); } @Test public void getCsrfTokenRepositorytNoSecurity() { loadConfig(Config.class); - assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) - .isInstanceOf(HttpSessionCsrfTokenRepository.class); + assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class); } @Test public void getCsrfTokenRepositorytSecurityNoCsrf() { loadConfig(SecurityNoCsrfConfig.class); - assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) - .isInstanceOf(HttpSessionCsrfTokenRepository.class); + assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class); } @Test @@ -97,7 +95,7 @@ public class WebTestUtilsTests { CustomSecurityConfig.CONTEXT_REPO = this.contextRepo; CustomSecurityConfig.CSRF_REPO = this.csrfRepo; loadConfig(CustomSecurityConfig.class); - assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo); + // assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo); } // getSecurityContextRepository diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java index ce33eae8bf..5da2cf58ca 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java @@ -40,7 +40,7 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt private final CsrfTokenRepository csrfTokenRepository; - private CsrfTokenRequestAttributeHandler requestAttributeHandler = new CsrfTokenRequestProcessor(); + private CsrfTokenRequestHandler requestHandler; /** * Creates a new instance @@ -48,30 +48,28 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt */ public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) { Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); + CsrfTokenRequestProcessor processor = new CsrfTokenRequestProcessor(); + processor.setTokenRepository(csrfTokenRepository); + this.requestHandler = processor; this.csrfTokenRepository = csrfTokenRepository; } /** - * Specify a {@link CsrfTokenRequestAttributeHandler} to use for making the - * {@code CsrfToken} available as a request attribute. - * @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use + * Specify a {@link CsrfTokenRequestHandler} to use for making the {@code CsrfToken} + * available as a request attribute. + * @param requestHandler the {@link CsrfTokenRequestHandler} to use */ - public void setRequestAttributeHandler(CsrfTokenRequestAttributeHandler requestAttributeHandler) { - Assert.notNull(requestAttributeHandler, "requestAttributeHandler cannot be null"); - this.requestAttributeHandler = requestAttributeHandler; + public void setRequestHandler(CsrfTokenRequestHandler requestHandler) { + Assert.notNull(requestHandler, "requestHandler cannot be null"); + this.requestHandler = requestHandler; } @Override public void onAuthentication(Authentication authentication, HttpServletRequest request, HttpServletResponse response) throws SessionAuthenticationException { - boolean containsToken = this.csrfTokenRepository.loadToken(request) != null; - if (containsToken) { - this.csrfTokenRepository.saveToken(null, request, response); - CsrfToken newToken = this.csrfTokenRepository.generateToken(request); - this.csrfTokenRepository.saveToken(newToken, request, response); - this.requestAttributeHandler.handle(request, response, () -> newToken); - this.logger.debug("Replaced CSRF Token"); - } + this.csrfTokenRepository.saveToken(null, request, response); + this.requestHandler.handle(request, response); + this.logger.debug("Replaced CSRF Token"); } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index c8d184e3a7..eb2ab9f979 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -81,21 +81,19 @@ public final class CsrfFilter extends OncePerRequestFilter { private final Log logger = LogFactory.getLog(getClass()); - private final CsrfTokenRepository tokenRepository; - private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER; private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); - private CsrfTokenRequestAttributeHandler requestAttributeHandler; + private CsrfTokenRequestHandler requestHandler; private CsrfTokenRequestResolver requestResolver; public CsrfFilter(CsrfTokenRepository csrfTokenRepository) { Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); - this.tokenRepository = csrfTokenRepository; CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor(); - this.requestAttributeHandler = csrfTokenRequestProcessor; + csrfTokenRequestProcessor.setTokenRepository(csrfTokenRepository); + this.requestHandler = csrfTokenRequestProcessor; this.requestResolver = csrfTokenRequestProcessor; } @@ -107,15 +105,7 @@ public final class CsrfFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - request.setAttribute(HttpServletResponse.class.getName(), response); - CsrfToken csrfToken = this.tokenRepository.loadToken(request); - boolean missingToken = (csrfToken == null); - if (missingToken) { - csrfToken = this.tokenRepository.generateToken(request); - this.tokenRepository.saveToken(csrfToken, request, response); - } - final CsrfToken finalCsrfToken = csrfToken; - this.requestAttributeHandler.handle(request, response, () -> finalCsrfToken); + DeferredCsrfToken deferredCsrfToken = this.requestHandler.handle(request, response); if (!this.requireCsrfProtectionMatcher.matches(request)) { if (this.logger.isTraceEnabled()) { this.logger.trace("Did not protect against CSRF since request did not match " @@ -124,8 +114,10 @@ public final class CsrfFilter extends OncePerRequestFilter { filterChain.doFilter(request, response); return; } + CsrfToken csrfToken = deferredCsrfToken.get(); String actualToken = this.requestResolver.resolveCsrfTokenValue(request, csrfToken); if (!equalsConstantTime(csrfToken.getToken(), actualToken)) { + boolean missingToken = deferredCsrfToken.isGenerated(); this.logger.debug( LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request))); AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken) @@ -172,18 +164,18 @@ public final class CsrfFilter extends OncePerRequestFilter { } /** - * Specifies a {@link CsrfTokenRequestAttributeHandler} that is used to make the + * Specifies a {@link CsrfTokenRequestHandler} that is used to make the * {@link CsrfToken} available as a request attribute. * *

* The default is {@link CsrfTokenRequestProcessor}. *

- * @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use + * @param requestHandler the {@link CsrfTokenRequestHandler} to use * @since 5.8 */ - public void setRequestAttributeHandler(CsrfTokenRequestAttributeHandler requestAttributeHandler) { - Assert.notNull(requestAttributeHandler, "requestAttributeHandler cannot be null"); - this.requestAttributeHandler = requestAttributeHandler; + public void setRequestHandler(CsrfTokenRequestHandler requestHandler) { + Assert.notNull(requestHandler, "requestHandler cannot be null"); + this.requestHandler = requestHandler; } /** diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java similarity index 73% rename from web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java rename to web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java index b45e340532..6fc4db61f5 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java @@ -16,14 +16,12 @@ package org.springframework.security.web.csrf; -import java.util.function.Supplier; - import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; /** - * A callback interface that is used to make the {@link CsrfToken} created by the - * {@link CsrfTokenRepository} available as a request attribute. Implementations of this + * A callback interface that is used to determine the {@link CsrfToken} to use and make + * the {@link CsrfToken} available as a request attribute. Implementations of this * interface may choose to perform additional tasks or customize how the token is made * available to the application through request attributes. * @@ -32,14 +30,13 @@ import jakarta.servlet.http.HttpServletResponse; * @see CsrfTokenRequestProcessor */ @FunctionalInterface -public interface CsrfTokenRequestAttributeHandler { +public interface CsrfTokenRequestHandler { /** * Handles a request using a {@link CsrfToken}. * @param request the {@code HttpServletRequest} being handled * @param response the {@code HttpServletResponse} being handled - * @param csrfToken the {@link CsrfToken} created by the {@link CsrfTokenRepository} */ - void handle(HttpServletRequest request, HttpServletResponse response, Supplier csrfToken); + DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java index 013d2190d2..dec8e54e8b 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java @@ -24,7 +24,7 @@ import jakarta.servlet.http.HttpServletResponse; import org.springframework.util.Assert; /** - * An implementation of the {@link CsrfTokenRequestAttributeHandler} and + * An implementation of the {@link CsrfTokenRequestHandler} and * {@link CsrfTokenRequestResolver} interfaces that is capable of making the * {@link CsrfToken} available as a request attribute and resolving the token value as * either a header or parameter value of the request. @@ -32,10 +32,22 @@ import org.springframework.util.Assert; * @author Steve Riesenberg * @since 5.8 */ -public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandler, CsrfTokenRequestResolver { +public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfTokenRequestResolver { private String csrfRequestAttributeName = "_csrf"; + private CsrfTokenRepository tokenRepository = new HttpSessionCsrfTokenRepository(); + + /** + * Sets the {@link CsrfTokenRepository} to use. + * @param tokenRepository the {@link CsrfTokenRepository} to use. Default + * {@link HttpSessionCsrfTokenRepository} + */ + public void setTokenRepository(CsrfTokenRepository tokenRepository) { + Assert.notNull(tokenRepository, "tokenRepository cannot be null"); + this.tokenRepository = tokenRepository; + } + /** * The {@link CsrfToken} is available as a request attribute named * {@code CsrfToken.class.getName()}. By default, an additional request attribute that @@ -49,16 +61,18 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandl } @Override - public void handle(HttpServletRequest request, HttpServletResponse response, Supplier csrfToken) { + public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) { Assert.notNull(request, "request cannot be null"); Assert.notNull(response, "response cannot be null"); - Assert.notNull(csrfToken, "csrfToken supplier cannot be null"); - CsrfToken actualCsrfToken = csrfToken.get(); - Assert.notNull(actualCsrfToken, "csrfToken cannot be null"); - request.setAttribute(CsrfToken.class.getName(), actualCsrfToken); + + request.setAttribute(HttpServletResponse.class.getName(), response); + DeferredCsrfToken deferredCsrfToken = new RepositoryDeferredCsrfToken(request, response); + CsrfToken csrfToken = new SupplierCsrfToken(deferredCsrfToken::get); + request.setAttribute(CsrfToken.class.getName(), csrfToken); String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName - : actualCsrfToken.getParameterName(); - request.setAttribute(csrfAttrName, actualCsrfToken); + : csrfToken.getParameterName(); + request.setAttribute(csrfAttrName, csrfToken); + return deferredCsrfToken; } @Override @@ -72,4 +86,78 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandl return actualToken; } + private static final class SupplierCsrfToken implements CsrfToken { + + private final Supplier csrfTokenSupplier; + + private SupplierCsrfToken(Supplier csrfTokenSupplier) { + this.csrfTokenSupplier = csrfTokenSupplier; + } + + @Override + public String getHeaderName() { + return getDelegate().getHeaderName(); + } + + @Override + public String getParameterName() { + return getDelegate().getParameterName(); + } + + @Override + public String getToken() { + return getDelegate().getToken(); + } + + private CsrfToken getDelegate() { + CsrfToken delegate = this.csrfTokenSupplier.get(); + if (delegate == null) { + throw new IllegalStateException("csrfTokenSupplier returned null delegate"); + } + return delegate; + } + + } + + private final class RepositoryDeferredCsrfToken implements DeferredCsrfToken { + + private final HttpServletRequest request; + + private final HttpServletResponse response; + + private CsrfToken csrfToken; + + private Boolean missingToken; + + RepositoryDeferredCsrfToken(HttpServletRequest request, HttpServletResponse response) { + this.request = request; + this.response = response; + } + + @Override + public CsrfToken get() { + init(); + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + init(); + return this.missingToken; + } + + private void init() { + if (this.csrfToken != null) { + return; + } + this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.loadToken(this.request); + this.missingToken = (this.csrfToken == null); + if (this.missingToken) { + this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.generateToken(this.request); + CsrfTokenRequestProcessor.this.tokenRepository.saveToken(this.csrfToken, this.request, this.response); + } + } + + } + } diff --git a/web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java new file mode 100644 index 0000000000..d8ab774570 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +/** + * An interface that allows delayed access to a {@link CsrfToken} that may be generated. + * + * @author Rob Winch + * @since 5.8 + */ +public interface DeferredCsrfToken { + + /*** + * Gets the {@link CsrfToken} + * @return a non-null {@link CsrfToken} + */ + CsrfToken get(); + + /** + * Returns true if {@link #get()} refers to a generated {@link CsrfToken} or false if + * it already existed. + * @return true if {@link #get()} refers to a generated {@link CsrfToken} or false if + * it already existed. + */ + boolean isGenerated(); + +} diff --git a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java index 082cbc12bf..ea1974a06a 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java @@ -27,7 +27,10 @@ import org.springframework.util.Assert; * * @author Rob Winch * @since 4.1 + * @deprecated Use org.springframework.security.web.csrf.CsrfTokenRequestHandler which + * returns a {@link DeferredCsrfToken} */ +@Deprecated public final class LazyCsrfTokenRepository implements CsrfTokenRepository { /** diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java index 60f1631bb3..f8010a0f76 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java @@ -32,6 +32,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -74,46 +75,44 @@ public class CsrfAuthenticationStrategyTests { } @Test - public void setRequestAttributeHandlerWhenNullThenIllegalStateException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setRequestAttributeHandler(null)) - .withMessage("requestAttributeHandler cannot be null"); + public void setRequestHandlerWhenNullThenIllegalStateException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setRequestHandler(null)) + .withMessage("requestHandler cannot be null"); } @Test - public void onAuthenticationWhenCustomRequestAttributeHandlerThenUsed() { - given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken); - given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); - - CsrfTokenRequestAttributeHandler requestAttributeHandler = mock(CsrfTokenRequestAttributeHandler.class); - this.strategy.setRequestAttributeHandler(requestAttributeHandler); + public void onAuthenticationWhenCustomRequestHandlerThenUsed() { + CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); + this.strategy.setRequestHandler(requestHandler); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); - verify(requestAttributeHandler).handle(eq(this.request), eq(this.response), any()); - verifyNoMoreInteractions(requestAttributeHandler); + verify(requestHandler).handle(eq(this.request), eq(this.response)); + verifyNoMoreInteractions(requestHandler); } @Test public void logoutRemovesCsrfTokenAndSavesNew() { - given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken); + given(this.csrfTokenRepository.loadToken(this.request)).willReturn(null, this.existingToken); given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); - verify(this.csrfTokenRepository).saveToken(null, this.request, this.response); - verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class), - any(HttpServletResponse.class)); // SEC-2404, SEC-2832 CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); assertThat(tokenInRequest.getToken()).isSameAs(this.generatedToken.getToken()); assertThat(tokenInRequest.getHeaderName()).isSameAs(this.generatedToken.getHeaderName()); assertThat(tokenInRequest.getParameterName()).isSameAs(this.generatedToken.getParameterName()); assertThat(this.request.getAttribute(this.generatedToken.getParameterName())).isSameAs(tokenInRequest); + // verify after the test accesses the CsrfToken which causes the lazy save to + // occur + verify(this.csrfTokenRepository).saveToken(null, this.request, this.response); + verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class), + any(HttpServletResponse.class)); } // SEC-2872 @Test public void delaySavingCsrf() { this.strategy = new CsrfAuthenticationStrategy(new LazyCsrfTokenRepository(this.csrfTokenRepository)); - given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken); given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); @@ -127,10 +126,10 @@ public class CsrfAuthenticationStrategyTests { } @Test - public void logoutRemovesNoActionIfNullToken() { + public void logoutWhenNoCsrfToken() { this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); - verify(this.csrfTokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), + verify(this.csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class)); } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index a33d9114ef..335a49471d 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -23,8 +23,6 @@ import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; -import org.assertj.core.api.AbstractObjectAssert; -import org.assertj.core.api.ObjectAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -45,10 +43,12 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; /** * @author Rob Winch @@ -127,8 +127,8 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -139,8 +139,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -151,8 +151,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -165,8 +165,8 @@ public class CsrfFilterTests { this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -176,8 +176,8 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(false); given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -187,8 +187,8 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(false); given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -199,8 +199,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -213,8 +213,8 @@ public class CsrfFilterTests { this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -225,8 +225,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), @@ -239,8 +239,8 @@ public class CsrfFilterTests { given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); // LazyCsrfTokenRepository requires the response as an attribute assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); verify(this.filterChain).doFilter(this.request, this.response); @@ -254,7 +254,6 @@ public class CsrfFilterTests { this.filter.setAccessDeniedHandler(this.deniedHandler); for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { resetRequestResponse(); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.setMethod(method); this.filter.doFilter(this.request, this.response, this.filterChain); verify(this.filterChain).doFilter(this.request, this.response); @@ -305,8 +304,8 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); verifyNoMoreInteractions(this.filterChain); } @@ -337,14 +336,14 @@ public class CsrfFilterTests { } @Test - public void doFilterWhenRequestAttributeHandlerThenUsed() throws Exception { - given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); - CsrfTokenRequestAttributeHandler requestAttributeHandler = mock(CsrfTokenRequestAttributeHandler.class); - this.filter.setRequestAttributeHandler(requestAttributeHandler); + public void doFilterWhenRequestHandlerThenUsed() throws Exception { + CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); + given(requestHandler.handle(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); + this.filter.setRequestHandler(requestHandler); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - verify(requestAttributeHandler).handle(eq(this.request), eq(this.response), any()); + verify(requestHandler).handle(eq(this.request), eq(this.response)); verify(this.filterChain).doFilter(this.request, this.response); } @@ -377,39 +376,40 @@ public class CsrfFilterTests { CsrfFilter filter = createCsrfFilter(this.tokenRepository); String csrfAttrName = "_csrf"; CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor(); + csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository); csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName); - filter.setRequestAttributeHandler(csrfTokenRequestProcessor); - CsrfToken expectedCsrfToken = mock(CsrfToken.class); + filter.setRequestHandler(csrfTokenRequestProcessor); + CsrfToken expectedCsrfToken = spy(this.token); given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken); filter.doFilter(this.request, this.response, this.filterChain); verifyNoInteractions(expectedCsrfToken); CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); - assertThat(tokenFromRequest).isEqualTo(expectedCsrfToken); + assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken); } - private static CsrfTokenAssert assertToken(Object token) { - return new CsrfTokenAssert((CsrfToken) token); - } + private static final class TestDeferredCsrfToken implements DeferredCsrfToken { - private static class CsrfTokenAssert extends AbstractObjectAssert { + private final CsrfToken csrfToken; - /** - * Creates a new {@link ObjectAssert}. - * @param actual the target to verify. - */ - protected CsrfTokenAssert(CsrfToken actual) { - super(actual, CsrfTokenAssert.class); + private final boolean isGenerated; + + private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) { + this.csrfToken = csrfToken; + this.isGenerated = isGenerated; } - CsrfTokenAssert isEqualTo(CsrfToken expected) { - assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName()); - assertThat(this.actual.getParameterName()).isEqualTo(expected.getParameterName()); - assertThat(this.actual.getToken()).isEqualTo(expected.getToken()); - return this; + @Override + public CsrfToken get() { + return this.csrfToken; } - } + @Override + public boolean isGenerated() { + return this.isGenerated; + } + + }; } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenAssert.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenAssert.java new file mode 100644 index 0000000000..cca2591110 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenAssert.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.Assertions; + +/** + * Assertion for validating the properties on CsrfToken are the same. + */ +public class CsrfTokenAssert extends AbstractAssert { + + protected CsrfTokenAssert(CsrfToken csrfToken) { + super(csrfToken, CsrfTokenAssert.class); + } + + public static CsrfTokenAssert assertThatCsrfToken(Object csrfToken) { + return new CsrfTokenAssert((CsrfToken) csrfToken); + } + + public static CsrfTokenAssert assertThat(CsrfToken csrfToken) { + return new CsrfTokenAssert(csrfToken); + } + + public CsrfTokenAssert isEqualTo(CsrfToken csrfToken) { + isNotNull(); + assertThat(csrfToken).isNotNull(); + Assertions.assertThat(this.actual.getHeaderName()).isEqualTo(csrfToken.getHeaderName()); + Assertions.assertThat(this.actual.getParameterName()).isEqualTo(csrfToken.getParameterName()); + Assertions.assertThat(this.actual.getToken()).isEqualTo(csrfToken.getToken()); + return this; + } + +} diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java index 542e0ef1f9..529ab5c821 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java @@ -18,12 +18,17 @@ package org.springframework.security.web.csrf; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; /** * Tests for {@link CsrfTokenRequestProcessor}. @@ -31,8 +36,12 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException * @author Steve Riesenberg * @since 5.8 */ +@ExtendWith(MockitoExtension.class) public class CsrfTokenRequestProcessorTests { + @Mock + CsrfTokenRepository tokenRepository; + private MockHttpServletRequest request; private MockHttpServletResponse response; @@ -47,48 +56,36 @@ public class CsrfTokenRequestProcessorTests { this.response = new MockHttpServletResponse(); this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); this.processor = new CsrfTokenRequestProcessor(); + this.processor.setTokenRepository(this.tokenRepository); } @Test public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.processor.handle(null, this.response, () -> this.token)) + assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(null, this.response)) .withMessage("request cannot be null"); } @Test public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.processor.handle(this.request, null, () -> this.token)) + assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, null)) .withMessage("response cannot be null"); } - @Test - public void handleWhenCsrfTokenSupplierIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, this.response, null)) - .withMessage("csrfToken supplier cannot be null"); - } - - @Test - public void handleWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.processor.handle(this.request, this.response, () -> null)) - .withMessage("csrfToken cannot be null"); - } - @Test public void handleWhenCsrfRequestAttributeSetThenUsed() { - this.processor.setCsrfRequestAttributeName("_csrf.attr"); - this.processor.handle(this.request, this.response, () -> this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); - assertThat(this.request.getAttribute("_csrf.attr")).isEqualTo(this.token); + given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); + this.processor.setCsrfRequestAttributeName("_csrf"); + this.processor.handle(this.request, this.response); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token); } @Test public void handleWhenValidParametersThenRequestAttributesSet() { - this.processor.handle(this.request, this.response, () -> this.token); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); - assertThat(this.request.getAttribute("_csrf")).isEqualTo(this.token); + given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); + this.processor.handle(this.request, this.response); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token); } @Test