diff --git a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java index 06b4a00c81..2d24c70328 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java @@ -33,10 +33,13 @@ import org.springframework.security.web.context.SaveContextOnUpdateOrErrorRespon import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler; +import org.springframework.security.web.server.csrf.CsrfToken; import org.springframework.stereotype.Controller; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; @@ -292,12 +295,15 @@ public class FormLoginTests { public static class HomePage { private WebDriver driver; + @FindBy(tagName = "body") + WebElement body; + public HomePage(WebDriver driver) { this.driver = driver; } public void assertAt() { - assertThat(this.driver.getPageSource()).contains("ok"); + assertThat(this.body.getText()).isEqualToIgnoringWhitespace("ok"); } static T to(WebDriver driver, Class page) { @@ -310,8 +316,10 @@ public class FormLoginTests { public static class CustomLoginPageController { @ResponseBody @GetMapping("/login") - public String login() { - return "\n" + public Mono login(ServerWebExchange exchange) { + Mono token = exchange.getAttribute(CsrfToken.class.getName()); + return token.map(t -> + "\n" + "\n" + " \n" + " \n" @@ -332,11 +340,12 @@ public class FormLoginTests { + " \n" + " \n" + "

\n" + + " \n" + " \n" + " \n" + " \n" + " \n" - + ""; + + ""); } }