diff --git a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java index d5c0c21190..082cbc12bf 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java @@ -38,6 +38,8 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { private final CsrfTokenRepository delegate; + private boolean deferLoadToken; + /** * Creates a new instance * @param delegate the {@link CsrfTokenRepository} to use. Cannot be null @@ -48,6 +50,15 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { this.delegate = delegate; } + /** + * Determines if {@link #loadToken(HttpServletRequest)} should be lazily loaded. + * @param deferLoadToken true if should lazily load + * {@link #loadToken(HttpServletRequest)}. Default false. + */ + public void setDeferLoadToken(boolean deferLoadToken) { + this.deferLoadToken = deferLoadToken; + } + /** * Generates a new token * @param request the {@link HttpServletRequest} to use. The @@ -77,6 +88,9 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { */ @Override public CsrfToken loadToken(HttpServletRequest request) { + if (this.deferLoadToken) { + return new LazyLoadCsrfToken(request, this.delegate); + } return this.delegate.loadToken(request); } @@ -92,6 +106,55 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { return response; } + private final class LazyLoadCsrfToken implements CsrfToken { + + private final HttpServletRequest request; + + private final CsrfTokenRepository tokenRepository; + + private CsrfToken token; + + private LazyLoadCsrfToken(HttpServletRequest request, CsrfTokenRepository tokenRepository) { + this.request = request; + this.tokenRepository = tokenRepository; + } + + private CsrfToken getDelegate() { + if (this.token != null) { + return this.token; + } + // load from the delegate repository + this.token = LazyCsrfTokenRepository.this.delegate.loadToken(this.request); + if (this.token == null) { + // return a generated token that is lazily saved since + // LazyCsrfTokenRepository#loadToken always returns a value + this.token = generateToken(this.request); + } + return this.token; + } + + @Override + public String getHeaderName() { + return getDelegate().getHeaderName(); + } + + @Override + public String getParameterName() { + return getDelegate().getParameterName(); + } + + @Override + public String getToken() { + return getDelegate().getToken(); + } + + @Override + public String toString() { + return "LazyLoadCsrfToken{" + "token=" + this.token + '}'; + } + + } + private static final class SaveOnAccessCsrfToken implements CsrfToken { private transient CsrfTokenRepository tokenRepository; diff --git a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java index e41ed59131..9be9d96518 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java @@ -30,6 +30,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyZeroInteractions; /** @@ -97,4 +98,15 @@ public class LazyCsrfTokenRepositoryTests { verify(this.delegate).loadToken(this.request); } + @Test + public void loadTokenWhenDeferLoadToken() { + given(this.delegate.loadToken(this.request)).willReturn(this.token); + this.repository.setDeferLoadToken(true); + CsrfToken loadToken = this.repository.loadToken(this.request); + verifyNoInteractions(this.delegate); + assertThat(loadToken.getToken()).isEqualTo(this.token.getToken()); + assertThat(loadToken.getHeaderName()).isEqualTo(this.token.getHeaderName()); + assertThat(loadToken.getParameterName()).isEqualTo(this.token.getParameterName()); + } + }