Browse Source

Add test support for Xor CSRF tokens

Issue gh-4001
pull/12044/head
Steve Riesenberg 3 years ago
parent
commit
440748ec65
No known key found for this signature in database
GPG Key ID: 5F311AB48A55D521
  1. 19
      config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java
  2. 15
      test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java
  3. 21
      test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java
  4. 10
      test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java
  5. 10
      test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java

19
config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java

@ -41,6 +41,7 @@ import org.springframework.security.web.FilterChainProxy; @@ -41,6 +41,7 @@ import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.stereotype.Controller;
import org.springframework.test.context.junit.jupiter.SpringExtension;
@ -301,7 +302,7 @@ public class CsrfConfigTests { @@ -301,7 +302,7 @@ public class CsrfConfigTests {
}
@Test
public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorThenOk() throws Exception {
public void postWhenUsingCsrfAndXorCsrfTokenRequestAttributeHandlerThenOk() throws Exception {
this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
.autowire();
// @formatter:off
@ -309,25 +310,27 @@ public class CsrfConfigTests { @@ -309,25 +310,27 @@ public class CsrfConfigTests {
.andExpect(status().isOk())
.andReturn();
MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
CsrfToken csrfToken = (CsrfToken) mvcResult.getRequest().getAttribute("_csrf");
MockHttpServletRequestBuilder ok = post("/ok")
.header(csrfToken.getHeaderName(), csrfToken.getToken())
.with(csrf())
.session(session);
this.mvc.perform(ok).andExpect(status().isOk());
// @formatter:on
}
@Test
public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorWithRawTokenThenForbidden() throws Exception {
public void postWhenUsingCsrfAndXorCsrfTokenRequestAttributeHandlerWithRawTokenThenForbidden() throws Exception {
this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
.autowire();
// @formatter:off
MvcResult mvcResult = this.mvc.perform(get("/ok"))
MvcResult mvcResult = this.mvc.perform(get("/csrf"))
.andExpect(status().isOk())
.andReturn();
MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
MockHttpServletRequest request = mvcResult.getRequest();
MockHttpSession session = (MockHttpSession) request.getSession();
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
CsrfToken csrfToken = repository.loadToken(request);
MockHttpServletRequestBuilder ok = post("/ok")
.with(csrf())
.header(csrfToken.getHeaderName(), csrfToken.getToken())
.session(session);
this.mvc.perform(ok).andExpect(status().isForbidden());
// @formatter:on
@ -594,7 +597,7 @@ public class CsrfConfigTests { @@ -594,7 +597,7 @@ public class CsrfConfigTests {
@Override
public void match(MvcResult result) throws Exception {
MockHttpServletRequest request = result.getRequest();
CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
assertThat(token).isNotNull();
assertThat(token.getToken()).isEqualTo(this.token.apply(result));
}

15
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@ -95,6 +95,8 @@ import org.springframework.security.web.context.SecurityContextRepository; @@ -95,6 +95,8 @@ import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc;
@ -499,6 +501,10 @@ public final class SecurityMockMvcRequestPostProcessors { @@ -499,6 +501,10 @@ public final class SecurityMockMvcRequestPostProcessors {
*/
public static final class CsrfRequestPostProcessor implements RequestPostProcessor {
private static final byte[] INVALID_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 };
private static final String INVALID_TOKEN_VALUE = Base64.getEncoder().encodeToString(INVALID_TOKEN_BYTES);
private boolean asHeader;
private boolean useInvalidToken;
@ -509,14 +515,17 @@ public final class SecurityMockMvcRequestPostProcessors { @@ -509,14 +515,17 @@ public final class SecurityMockMvcRequestPostProcessors {
@Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request);
if (!(repository instanceof TestCsrfTokenRepository)) {
repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository());
WebTestUtils.setCsrfTokenRepository(request, repository);
}
TestCsrfTokenRepository.enable(request);
CsrfToken token = repository.generateToken(request);
repository.saveToken(token, request, new MockHttpServletResponse());
String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken();
MockHttpServletResponse response = new MockHttpServletResponse();
DeferredCsrfToken deferredCsrfToken = repository.loadDeferredToken(request, response);
handler.handle(request, response, deferredCsrfToken::get);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
String tokenValue = this.useInvalidToken ? INVALID_TOKEN_VALUE : token.getToken();
if (this.asHeader) {
request.addHeader(token.getHeaderName(), tokenValue);
}

21
test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

@ -31,7 +31,9 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter @@ -31,7 +31,9 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
@ -48,6 +50,8 @@ public abstract class WebTestUtils { @@ -48,6 +50,8 @@ public abstract class WebTestUtils {
private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository();
private static final CsrfTokenRequestHandler DEFAULT_CSRF_HANDLER = new XorCsrfTokenRequestAttributeHandler();
private WebTestUtils() {
}
@ -107,6 +111,23 @@ public abstract class WebTestUtils { @@ -107,6 +111,23 @@ public abstract class WebTestUtils {
return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository");
}
/**
* Gets the {@link CsrfTokenRequestHandler} for the specified
* {@link HttpServletRequest}. If one is not found, the default
* {@link XorCsrfTokenRequestAttributeHandler} is used.
* @param request the {@link HttpServletRequest} to obtain the
* {@link CsrfTokenRequestHandler}
* @return the {@link CsrfTokenRequestHandler} for the specified
* {@link HttpServletRequest}
*/
public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) {
CsrfFilter filter = findFilter(request, CsrfFilter.class);
if (filter == null) {
return DEFAULT_CSRF_HANDLER;
}
return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler");
}
/**
* Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}.
* @param request the {@link HttpServletRequest} to obtain the

10
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java

@ -25,7 +25,6 @@ import org.springframework.http.HttpMethod; @@ -25,7 +25,6 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
@ -52,8 +51,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { @@ -52,8 +51,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
@Test
public void defaults() {
MockHttpServletRequest request = formLogin().buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
assertThat(request.getParameter("username")).isEqualTo("user");
assertThat(request.getParameter("password")).isEqualTo("password");
assertThat(request.getMethod()).isEqualTo("POST");
@ -66,8 +64,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { @@ -66,8 +64,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
public void custom() {
MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret")
.buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
assertThat(request.getParameter("username")).isEqualTo("admin");
assertThat(request.getParameter("password")).isEqualTo("secret");
assertThat(request.getMethod()).isEqualTo("POST");
@ -79,8 +76,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { @@ -79,8 +76,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
public void customWithUriVars() {
MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2")
.user("username", "admin").password("password", "secret").buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
assertThat(request.getParameter("username")).isEqualTo("admin");
assertThat(request.getParameter("password")).isEqualTo("secret");
assertThat(request.getMethod()).isEqualTo("POST");

10
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java

@ -25,7 +25,6 @@ import org.springframework.http.HttpMethod; @@ -25,7 +25,6 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
@ -52,8 +51,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { @@ -52,8 +51,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
@Test
public void defaults() {
MockHttpServletRequest request = logout().buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/logout");
@ -62,8 +60,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { @@ -62,8 +60,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
@Test
public void custom() {
MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
@ -73,8 +70,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { @@ -73,8 +70,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
public void customWithUriVars() {
MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2")
.buildRequest(this.servletContext);
CsrfToken token = (CsrfToken) request
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");

Loading…
Cancel
Save