diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java index c1a074af76..bf94a98ee3 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java @@ -51,8 +51,11 @@ import org.springframework.security.web.context.SecurityContextHolderFilter; import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; +import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler; import org.springframework.security.web.header.HeaderWriterFilter; import org.springframework.security.web.savedrequest.RequestCacheAwareFilter; import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter; @@ -121,8 +124,12 @@ public class DefaultFiltersTests { MockHttpServletRequest request = new MockHttpServletRequest("POST", ""); request.setServletPath("/logout"); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); - new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request, response); - request.setParameter(csrfToken.getParameterName(), csrfToken.getToken()); + CsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); + repository.saveToken(csrfToken, request, response); + CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); + handler.handle(request, response, () -> csrfToken); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); + request.setParameter(token.getParameterName(), token.getToken()); this.spring.getContext().getBean("springSecurityFilterChain", Filter.class).doFilter(request, response, new MockFilterChain()); assertThat(response.getRedirectedUrl()).isEqualTo("/login?logout"); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java index 75746f223e..e906f89f47 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java @@ -85,7 +85,9 @@ public class DefaultLoginPageConfigurerTests { String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); // @formatter:off this.mvc.perform(get("/login").sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string("\n" + .andExpect((result) -> { + CsrfToken token = (CsrfToken) result.getRequest().getAttribute(CsrfToken.class.getName()); + assertThat(result.getResponse().getContentAsString()).isEqualTo("\n" + "\n" + "
\n" + " \n" @@ -108,11 +110,12 @@ public class DefaultLoginPageConfigurerTests { + " \n" + " \n" + " \n" - + "\n" + + "\n" + " \n" + " \n" + "\n" - + "")); + + "