diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index e856905cf0..718ccdf41c 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -191,7 +191,7 @@ public class CsrfWebFilter implements WebFilter { private Mono generateToken(ServerWebExchange exchange) { return this.csrfTokenRepository.generateToken(exchange) - .delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token)); + .delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token)).cache(); } private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher { diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java index c531bfca3f..80d9fd74ad 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java @@ -16,6 +16,9 @@ package org.springframework.security.web.server.csrf; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -227,6 +230,26 @@ public class CsrfWebFilterTests { .isForbidden(); } + // gh-9113 + @Test + public void filterWhenSubscribingCsrfTokenMultipleTimesThenGenerateOnlyOnce() { + this.csrfFilter.setCsrfTokenRepository(this.repository); + given(this.repository.loadToken(any())).willReturn(Mono.empty()); + AtomicInteger count = new AtomicInteger(); + given(this.repository.generateToken(any())).willReturn(Mono.fromCallable(() -> { + count.incrementAndGet(); + return this.token; + })); + given(this.repository.saveToken(any(), any())).willReturn(Mono.empty()); + AtomicReference> tokenFromExchange = new AtomicReference<>(); + given(this.chain.filter(any())).willReturn( + Mono.fromRunnable(() -> tokenFromExchange.set(this.get.getAttribute(CsrfToken.class.getName())))); + this.csrfFilter.filter(this.get, this.chain).block(); + tokenFromExchange.get().block(); + tokenFromExchange.get().block(); + assertThat(count).hasValue(1); + } + @RestController static class OkController {