diff --git a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java index 8b03382451..58d271f8d5 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java @@ -105,9 +105,11 @@ public class XorCsrfTokenRequestAttributeHandlerTests { @Test public void handleWhenCsrfTokenIsNullThenThrowsIllegalStateException() { + this.handler.handle(this.request, this.response, () -> null); + CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute("_csrf"); // @formatter:off assertThatIllegalStateException() - .isThrownBy(() -> this.handler.handle(this.request, this.response, () -> null)) + .isThrownBy(csrfTokenAttribute::getToken) .withMessage("csrfToken supplier returned null"); // @formatter:on } @@ -128,8 +130,12 @@ public class XorCsrfTokenRequestAttributeHandlerTests { @Test public void handleWhenSecureRandomSetThenUsed() { + willAnswer(fillByteArray()).given(this.secureRandom).nextBytes(anyByteArray()); + this.handler.setSecureRandom(this.secureRandom); this.handler.handle(this.request, this.response, () -> this.token); + CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); + assertThat(csrfTokenAttribute.getToken()).isEqualTo(XOR_CSRF_TOKEN_VALUE); verify(this.secureRandom).nextBytes(anyByteArray()); verifyNoMoreInteractions(this.secureRandom); } @@ -140,12 +146,11 @@ public class XorCsrfTokenRequestAttributeHandlerTests { this.handler.setSecureRandom(this.secureRandom); this.handler.handle(this.request, this.response, () -> this.token); - verify(this.secureRandom).nextBytes(anyByteArray()); - assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); - assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull(); - CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); assertThat(csrfTokenAttribute.getToken()).isEqualTo(XOR_CSRF_TOKEN_VALUE); + verify(this.secureRandom).nextBytes(anyByteArray()); + assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute("_csrf")).isNotNull(); } @Test