From 7622826b69438db6ec3fb6d3aef82bb632577dfa Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 7 Nov 2017 21:59:47 -0600 Subject: [PATCH] WebSessionServerCsrfTokenRepository saves on getToken Fixes gh-4801 --- .../config/web/server/FormLoginTests.java | 10 +-- .../web/server/csrf/CsrfWebFilter.java | 11 ++- .../web/server/csrf/DefaultCsrfToken.java | 24 +++++++ .../WebSessionServerCsrfTokenRepository.java | 68 ++++++++++++++++++- .../ui/LoginPageGeneratingWebFilter.java | 5 +- .../ui/LogoutPageGeneratingWebFilter.java | 5 +- ...SessionServerCsrfTokenRepositoryTests.java | 40 +++++------ 7 files changed, 123 insertions(+), 40 deletions(-) diff --git a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java index 2d24c70328..1b8b8f6c7d 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java @@ -316,9 +316,9 @@ public class FormLoginTests { public static class CustomLoginPageController { @ResponseBody @GetMapping("/login") - public Mono login(ServerWebExchange exchange) { - Mono token = exchange.getAttribute(CsrfToken.class.getName()); - return token.map(t -> + public String login(ServerWebExchange exchange) { + CsrfToken token = exchange.getAttribute(CsrfToken.class.getName()); + return "\n" + "\n" + " \n" @@ -340,12 +340,12 @@ public class FormLoginTests { + " \n" + " \n" + "

\n" - + " \n" + + " \n" + " \n" + " \n" + " \n" + " \n" - + ""); + + ""; } } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index 38616fa4c4..33a1570d30 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -106,14 +106,19 @@ public class CsrfWebFilter implements WebFilter { private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { return csrfToken(exchange) .doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken)) + .doOnSuccess(csrfToken -> exchange.getAttributes().put(csrfToken.getParameterName(), csrfToken)) .flatMap( t -> chain.filter(exchange)) .then(); } - private Mono> csrfToken(ServerWebExchange exchange) { + private Mono csrfToken(ServerWebExchange exchange) { return this.serverCsrfTokenRepository.loadToken(exchange) - .switchIfEmpty(this.serverCsrfTokenRepository.generateToken(exchange)) - .as(Mono::just); // FIXME eager saving of CsrfToken with .as + .switchIfEmpty(generateToken(exchange)); + } + + private Mono generateToken(ServerWebExchange exchange) { + return this.serverCsrfTokenRepository.generateToken(exchange) + .flatMap(token -> this.serverCsrfTokenRepository.saveToken(exchange, token)); } private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher { diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java b/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java index 0e75316ebb..a5d543673d 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java @@ -74,4 +74,28 @@ public final class DefaultCsrfToken implements CsrfToken { public String getToken() { return this.token; } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || !(o instanceof CsrfToken)) + return false; + + CsrfToken that = (CsrfToken) o; + + if (!getToken().equals(that.getToken())) + return false; + if (!getParameterName().equals(that.getParameterName())) + return false; + return getHeaderName().equals(that.getHeaderName()); + } + + @Override + public int hashCode() { + int result = getToken().hashCode(); + result = 31 * result + getParameterName().hashCode(); + result = 31 * result + getHeaderName().hashCode(); + return result; + } } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java index 021359272a..80157c7272 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java @@ -49,12 +49,16 @@ public class WebSessionServerCsrfTokenRepository @Override public Mono generateToken(ServerWebExchange exchange) { - return Mono.defer(() -> Mono.just(createCsrfToken())) - .flatMap(token -> saveToken(exchange, token)); + return exchange.getSession() + .map(WebSession::getAttributes) + .map(this::createCsrfToken); } @Override public Mono saveToken(ServerWebExchange exchange, CsrfToken token) { + if(token != null) { + return Mono.just(token); + } return exchange.getSession() .map(WebSession::getAttributes) .flatMap( attrs -> save(attrs, token)); @@ -113,6 +117,11 @@ public class WebSessionServerCsrfTokenRepository this.sessionAttributeName = sessionAttributeName; } + + private CsrfToken createCsrfToken(Map attributes) { + return new LazyCsrfToken(attributes, createCsrfToken()); + } + private CsrfToken createCsrfToken() { return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken()); } @@ -120,4 +129,59 @@ public class WebSessionServerCsrfTokenRepository private String createNewToken() { return UUID.randomUUID().toString(); } + + private class LazyCsrfToken implements CsrfToken { + private final Map attributes; + private final CsrfToken delegate; + + private LazyCsrfToken(Map attributes, CsrfToken delegate) { + this.attributes = attributes; + this.delegate = delegate; + } + + @Override + public String getHeaderName() { + return this.delegate.getHeaderName(); + } + + @Override + public String getParameterName() { + return this.delegate.getParameterName(); + } + + @Override + public String getToken() { + putToken(this.attributes, this.delegate); + return this.delegate.getToken(); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || !(o instanceof CsrfToken)) + return false; + + CsrfToken that = (CsrfToken) o; + + if (!getToken().equals(that.getToken())) + return false; + if (!getParameterName().equals(that.getParameterName())) + return false; + return getHeaderName().equals(that.getHeaderName()); + } + + @Override + public int hashCode() { + int result = getToken().hashCode(); + result = 31 * result + getParameterName().hashCode(); + result = 31 * result + getHeaderName().hashCode(); + return result; + } + + @Override + public String toString() { + return "LazyCsrfToken{" + "delegate=" + this.delegate + '}'; + } + } } diff --git a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java index 78d8a463a0..39d82b53ac 100644 --- a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java @@ -61,9 +61,8 @@ public class LoginPageGeneratingWebFilter implements WebFilter { private Mono createBuffer(ServerWebExchange exchange) { MultiValueMap queryParams = exchange.getRequest() .getQueryParams(); - Mono token = (Mono) exchange.getAttributes() - .getOrDefault(CsrfToken.class.getName(), Mono.empty()); - return token + CsrfToken token = exchange.getAttribute(CsrfToken.class.getName()); + return Mono.justOrEmpty(token) .map(LoginPageGeneratingWebFilter::csrfToken) .defaultIfEmpty("") .map(csrfTokenHtmlInput -> { diff --git a/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java b/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java index aa411ce0d0..c86f5b65e9 100644 --- a/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java @@ -58,9 +58,8 @@ public class LogoutPageGeneratingWebFilter implements WebFilter { } private Mono createBuffer(ServerWebExchange exchange) { - Mono token = (Mono) exchange.getAttributes() - .getOrDefault(CsrfToken.class.getName(), Mono.empty()); - return token + CsrfToken token = exchange.getAttribute(CsrfToken.class.getName()); + return Mono.justOrEmpty(token) .map(LogoutPageGeneratingWebFilter::csrfToken) .defaultIfEmpty("") .map(csrfTokenHtmlInput -> { diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java index 4480f11f75..e9fb8cc7ab 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java @@ -37,7 +37,7 @@ public class WebSessionServerCsrfTokenRepositoryTests { private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); @Test - public void generateTokenWhenNoSubscriptionThenNoSession() { + public void generateTokenThenNoSession() { Mono result = this.repository.generateToken(this.exchange); Mono isSessionStarted = this.exchange.getSession() @@ -49,43 +49,34 @@ public class WebSessionServerCsrfTokenRepositoryTests { } @Test - public void generateTokenWhenSubscriptionThenAddsToSession() { + public void generateTokenWhenSubscriptionThenNoSession() { Mono result = this.repository.generateToken(this.exchange); - StepVerifier.create(result) - .consumeNextWith( t -> assertThat(t).isNotNull()) - .verifyComplete(); - - WebSession session = this.exchange.getSession().block(); - Map attributes = session.getAttributes(); - - assertThat(session.isStarted()).isTrue(); - assertThat(attributes).hasSize(1); - assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class); + Mono isSessionStarted = this.exchange.getSession() + .map(WebSession::isStarted); + StepVerifier.create(isSessionStarted) + .expectNext(false) + .verifyComplete(); } @Test - public void saveTokenWhenSetSessionAttributeNameAndSubscriptionThenAddsToSession() { - CsrfToken token = new DefaultCsrfToken("h","p", "t"); - String attrName = "ATTR"; - this.repository.setSessionAttributeName(attrName); - Mono result = this.repository.saveToken(this.exchange, token); - - StepVerifier.create(result) - .consumeNextWith(n -> assertThat(n).isEqualTo(token)) - .verifyComplete(); + public void generateTokenWhenGetTokenThenAddsToSession() { + Mono result = this.repository.generateToken(this.exchange); + result.block().getToken(); WebSession session = this.exchange.getSession().block(); + Map attributes = session.getAttributes(); assertThat(session.isStarted()).isTrue(); - assertThat(session.getAttribute(attrName)).isEqualTo(token); + assertThat(attributes).hasSize(1); + assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class); } @Test public void saveTokenWhenNullThenDeletes() { - CsrfToken token = new DefaultCsrfToken("h","p", "t"); - this.repository.saveToken(this.exchange, token).block(); + CsrfToken token = this.repository.generateToken(this.exchange).block(); + token.getToken(); Mono result = this.repository.saveToken(this.exchange, null); StepVerifier.create(result) @@ -99,6 +90,7 @@ public class WebSessionServerCsrfTokenRepositoryTests { @Test public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() { CsrfToken generate = this.repository.generateToken(this.exchange).block(); + generate.getToken(); CsrfToken load = this.repository.loadToken(this.exchange).block(); assertThat(load).isEqualTo(generate);