From 440748ec65199a22bc07e43b2f0f55ff6c9d0478 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Wed, 12 Oct 2022 11:11:52 -0500 Subject: [PATCH] Add test support for Xor CSRF tokens Issue gh-4001 --- .../security/config/http/CsrfConfigTests.java | 19 ++++++++++------- .../SecurityMockMvcRequestPostProcessors.java | 15 ++++++++++--- .../test/web/support/WebTestUtils.java | 21 +++++++++++++++++++ ...yMockMvcRequestBuildersFormLoginTests.java | 10 +++------ ...MockMvcRequestBuildersFormLogoutTests.java | 10 +++------ 5 files changed, 50 insertions(+), 25 deletions(-) diff --git a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java index 4da7fbc565..1b12b3e28c 100644 --- a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java @@ -41,6 +41,7 @@ import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.stereotype.Controller; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -301,7 +302,7 @@ public class CsrfConfigTests { } @Test - public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorThenOk() throws Exception { + public void postWhenUsingCsrfAndXorCsrfTokenRequestAttributeHandlerThenOk() throws Exception { this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers")) .autowire(); // @formatter:off @@ -309,25 +310,27 @@ public class CsrfConfigTests { .andExpect(status().isOk()) .andReturn(); MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession(); - CsrfToken csrfToken = (CsrfToken) mvcResult.getRequest().getAttribute("_csrf"); MockHttpServletRequestBuilder ok = post("/ok") - .header(csrfToken.getHeaderName(), csrfToken.getToken()) + .with(csrf()) .session(session); this.mvc.perform(ok).andExpect(status().isOk()); // @formatter:on } @Test - public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorWithRawTokenThenForbidden() throws Exception { + public void postWhenUsingCsrfAndXorCsrfTokenRequestAttributeHandlerWithRawTokenThenForbidden() throws Exception { this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers")) .autowire(); // @formatter:off - MvcResult mvcResult = this.mvc.perform(get("/ok")) + MvcResult mvcResult = this.mvc.perform(get("/csrf")) .andExpect(status().isOk()) .andReturn(); - MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession(); + MockHttpServletRequest request = mvcResult.getRequest(); + MockHttpSession session = (MockHttpSession) request.getSession(); + CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); + CsrfToken csrfToken = repository.loadToken(request); MockHttpServletRequestBuilder ok = post("/ok") - .with(csrf()) + .header(csrfToken.getHeaderName(), csrfToken.getToken()) .session(session); this.mvc.perform(ok).andExpect(status().isForbidden()); // @formatter:on @@ -594,7 +597,7 @@ public class CsrfConfigTests { @Override public void match(MvcResult result) throws Exception { MockHttpServletRequest request = result.getRequest(); - CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); assertThat(token).isNotNull(); assertThat(token.getToken()).isEqualTo(this.token.apply(result)); } diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 4eb0f6b324..d0fcb055b8 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -95,6 +95,8 @@ import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; @@ -499,6 +501,10 @@ public final class SecurityMockMvcRequestPostProcessors { */ public static final class CsrfRequestPostProcessor implements RequestPostProcessor { + private static final byte[] INVALID_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 }; + + private static final String INVALID_TOKEN_VALUE = Base64.getEncoder().encodeToString(INVALID_TOKEN_BYTES); + private boolean asHeader; private boolean useInvalidToken; @@ -509,14 +515,17 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); + CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request); if (!(repository instanceof TestCsrfTokenRepository)) { repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository()); WebTestUtils.setCsrfTokenRepository(request, repository); } TestCsrfTokenRepository.enable(request); - CsrfToken token = repository.generateToken(request); - repository.saveToken(token, request, new MockHttpServletResponse()); - String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken(); + MockHttpServletResponse response = new MockHttpServletResponse(); + DeferredCsrfToken deferredCsrfToken = repository.loadDeferredToken(request, response); + handler.handle(request, response, deferredCsrfToken::get); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); + String tokenValue = this.useInvalidToken ? INVALID_TOKEN_VALUE : token.getToken(); if (this.asHeader) { request.addHeader(token.getHeaderName(), tokenValue); } diff --git a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java index c13ebdefe3..0978ec370b 100644 --- a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java +++ b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java @@ -31,7 +31,9 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; +import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.support.WebApplicationContextUtils; @@ -48,6 +50,8 @@ public abstract class WebTestUtils { private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository(); + private static final CsrfTokenRequestHandler DEFAULT_CSRF_HANDLER = new XorCsrfTokenRequestAttributeHandler(); + private WebTestUtils() { } @@ -107,6 +111,23 @@ public abstract class WebTestUtils { return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository"); } + /** + * Gets the {@link CsrfTokenRequestHandler} for the specified + * {@link HttpServletRequest}. If one is not found, the default + * {@link XorCsrfTokenRequestAttributeHandler} is used. + * @param request the {@link HttpServletRequest} to obtain the + * {@link CsrfTokenRequestHandler} + * @return the {@link CsrfTokenRequestHandler} for the specified + * {@link HttpServletRequest} + */ + public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) { + CsrfFilter filter = findFilter(request, CsrfFilter.class); + if (filter == null) { + return DEFAULT_CSRF_HANDLER; + } + return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler"); + } + /** * Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}. * @param request the {@link HttpServletRequest} to obtain the diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java index 9dea5175bf..252ef0fa53 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java @@ -25,7 +25,6 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockServletContext; -import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -52,8 +51,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { @Test public void defaults() { MockHttpServletRequest request = formLogin().buildRequest(this.servletContext); - CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("password")).isEqualTo("password"); assertThat(request.getMethod()).isEqualTo("POST"); @@ -66,8 +64,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { public void custom() { MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret") .buildRequest(this.servletContext); - CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); @@ -79,8 +76,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { public void customWithUriVars() { MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2") .user("username", "admin").password("password", "secret").buildRequest(this.servletContext); - CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java index df6e7cfef2..728f7c45a4 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java @@ -25,7 +25,6 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockServletContext; -import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -52,8 +51,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { @Test public void defaults() { MockHttpServletRequest request = logout().buildRequest(this.servletContext); - CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/logout"); @@ -62,8 +60,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { @Test public void custom() { MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext); - CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/admin/logout"); @@ -73,8 +70,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { public void customWithUriVars() { MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2") .buildRequest(this.servletContext); - CsrfToken token = (CsrfToken) request - .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");