From d55db837e10570e24d0355c7be070dcb591c091c Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Mon, 20 Nov 2017 14:16:49 -0600 Subject: [PATCH] CsrfWebFilter places Mono Fixes: gh-4855 --- .../config/web/server/FormLoginTests.java | 11 +-- .../config/web/server/RequestCacheTests.java | 2 - .../java/sample/CsrfControllerAdvice.java | 38 ++++++++++ .../view/CsrfRequestDataValueProcessor.java | 6 +- .../web/server/csrf/CsrfWebFilter.java | 16 +++-- .../WebSessionServerCsrfTokenRepository.java | 71 +------------------ .../ui/LoginPageGeneratingWebFilter.java | 4 +- .../ui/LogoutPageGeneratingWebFilter.java | 4 +- .../CsrfRequestDataValueProcessorTests.java | 5 +- .../web/server/csrf/CsrfWebFilterTests.java | 6 -- ...SessionServerCsrfTokenRepositoryTests.java | 24 ++----- 11 files changed, 73 insertions(+), 114 deletions(-) create mode 100644 samples/javaconfig/webflux-form/src/main/java/sample/CsrfControllerAdvice.java diff --git a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java index 3dc4ba41b7..70b0fa60f8 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java @@ -33,6 +33,7 @@ import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; @@ -314,9 +315,9 @@ public class FormLoginTests { public static class CustomLoginPageController { @ResponseBody @GetMapping("/login") - public String login(ServerWebExchange exchange) { - CsrfToken token = exchange.getAttribute(CsrfToken.class.getName()); - return + public Mono login(ServerWebExchange exchange) { + Mono token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty()); + return token.map(t -> "\n" + "\n" + " \n" @@ -338,12 +339,12 @@ public class FormLoginTests { + " \n" + " \n" + "

\n" - + " \n" + + " \n" + " \n" + " \n" + " \n" + " \n" - + ""; + + ""); } } } diff --git a/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java b/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java index 994ca5ee1f..d163db34a4 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java @@ -26,7 +26,6 @@ import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverB import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; -import org.springframework.security.web.server.csrf.CsrfToken; import org.springframework.security.web.server.savedrequest.NoOpServerRequestCache; import org.springframework.stereotype.Controller; import org.springframework.test.web.reactive.server.WebTestClient; @@ -126,7 +125,6 @@ public class RequestCacheTests { @ResponseBody @GetMapping("/secured") public String login(ServerWebExchange exchange) { - CsrfToken token = exchange.getAttribute(CsrfToken.class.getName()); return "\n" + "\n" diff --git a/samples/javaconfig/webflux-form/src/main/java/sample/CsrfControllerAdvice.java b/samples/javaconfig/webflux-form/src/main/java/sample/CsrfControllerAdvice.java new file mode 100644 index 0000000000..fbc695f192 --- /dev/null +++ b/samples/javaconfig/webflux-form/src/main/java/sample/CsrfControllerAdvice.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sample; + +import org.springframework.security.web.server.csrf.CsrfToken; +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import static org.springframework.security.web.reactive.result.view.CsrfRequestDataValueProcessor.DEFAULT_CSRF_ATTR_NAME; + +/** + * @author Rob Winch + * @since 5.0 + */ +@ControllerAdvice +public class CsrfControllerAdvice { + @ModelAttribute + public Mono csrfToken(ServerWebExchange exchange) { + Mono csrfToken = exchange.getAttribute(CsrfToken.class.getName()); + return csrfToken.doOnSuccess(token -> exchange.getAttributes().put(DEFAULT_CSRF_ATTR_NAME, token)); + } +} diff --git a/web/src/main/java/org/springframework/security/web/reactive/result/view/CsrfRequestDataValueProcessor.java b/web/src/main/java/org/springframework/security/web/reactive/result/view/CsrfRequestDataValueProcessor.java index a962a127ab..506ec2c2fe 100644 --- a/web/src/main/java/org/springframework/security/web/reactive/result/view/CsrfRequestDataValueProcessor.java +++ b/web/src/main/java/org/springframework/security/web/reactive/result/view/CsrfRequestDataValueProcessor.java @@ -30,6 +30,10 @@ import java.util.regex.Pattern; * @since 5.0 */ public class CsrfRequestDataValueProcessor implements RequestDataValueProcessor { + /** + * The default request attribute to look for a {@link CsrfToken}. + */ + public static final String DEFAULT_CSRF_ATTR_NAME = "_csrf"; private static final Pattern DISABLE_CSRF_TOKEN_PATTERN = Pattern .compile("(?i)^(GET|HEAD|TRACE|OPTIONS)$"); @@ -62,7 +66,7 @@ public class CsrfRequestDataValueProcessor implements RequestDataValueProcessor exchange.getAttributes().remove(DISABLE_CSRF_TOKEN_ATTR); return Collections.emptyMap(); } - CsrfToken token = exchange.getAttribute(CsrfToken.class.getName()); + CsrfToken token = exchange.getAttribute(DEFAULT_CSRF_ATTR_NAME); if(token == null) { return Collections.emptyMap(); } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index ab8e748339..f797fbe6e7 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -47,12 +47,16 @@ import java.util.Set; * {@link WebSessionServerCsrfTokenRepository}. This is preferred to storing the token in * a cookie which can be modified by a client application. *

+ *

+ * The {@code Mono<CsrfToken>} is exposes as a request attribute with the name of + * {@code CsrfToken.class.getName()}. If the token is new it will automatically be saved + * at the time it is subscribed. + *

* * @author Rob Winch * @since 5.0 */ public class CsrfWebFilter implements WebFilter { - private ServerWebExchangeMatcher requireCsrfProtectionMatcher = new DefaultRequireCsrfProtectionMatcher(); private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository(); @@ -105,11 +109,11 @@ public class CsrfWebFilter implements WebFilter { } private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { - return csrfToken(exchange) - .doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken)) - .doOnSuccess(csrfToken -> exchange.getAttributes().put(csrfToken.getParameterName(), csrfToken)) - .flatMap( t -> chain.filter(exchange)) - .then(); + return Mono.defer(() ->{ + Mono csrfToken = csrfToken(exchange); + exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken); + return chain.filter(exchange); + }); } private Mono csrfToken(ServerWebExchange exchange) { diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java index 47ad336d82..908f00ae07 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java @@ -17,7 +17,6 @@ package org.springframework.security.web.server.csrf; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; -import org.springframework.web.server.WebSession; import reactor.core.publisher.Mono; import javax.servlet.http.HttpServletRequest; @@ -49,20 +48,15 @@ public class WebSessionServerCsrfTokenRepository @Override public Mono generateToken(ServerWebExchange exchange) { - return exchange.getSession() - .map(WebSession::getAttributes) - .map(this::createCsrfToken); + return Mono.fromCallable(() -> createCsrfToken()); } @Override public Mono saveToken(ServerWebExchange exchange, CsrfToken token) { - if(token != null) { - return Mono.just(token); - } return exchange.getSession() - .doOnSuccess(session -> putToken(session.getAttributes(), token)) + .doOnNext(session -> putToken(session.getAttributes(), token)) .flatMap(session -> session.changeSessionId()) - .flatMap(r -> Mono.justOrEmpty(token)); + .then(Mono.justOrEmpty(token)); } private void putToken(Map attributes, CsrfToken token) { @@ -111,11 +105,6 @@ public class WebSessionServerCsrfTokenRepository this.sessionAttributeName = sessionAttributeName; } - - private CsrfToken createCsrfToken(Map attributes) { - return new LazyCsrfToken(attributes, createCsrfToken()); - } - private CsrfToken createCsrfToken() { return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken()); } @@ -124,58 +113,4 @@ public class WebSessionServerCsrfTokenRepository return UUID.randomUUID().toString(); } - private class LazyCsrfToken implements CsrfToken { - private final Map attributes; - private final CsrfToken delegate; - - private LazyCsrfToken(Map attributes, CsrfToken delegate) { - this.attributes = attributes; - this.delegate = delegate; - } - - @Override - public String getHeaderName() { - return this.delegate.getHeaderName(); - } - - @Override - public String getParameterName() { - return this.delegate.getParameterName(); - } - - @Override - public String getToken() { - putToken(this.attributes, this.delegate); - return this.delegate.getToken(); - } - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || !(o instanceof CsrfToken)) - return false; - - CsrfToken that = (CsrfToken) o; - - if (!getToken().equals(that.getToken())) - return false; - if (!getParameterName().equals(that.getParameterName())) - return false; - return getHeaderName().equals(that.getHeaderName()); - } - - @Override - public int hashCode() { - int result = getToken().hashCode(); - result = 31 * result + getParameterName().hashCode(); - result = 31 * result + getHeaderName().hashCode(); - return result; - } - - @Override - public String toString() { - return "LazyCsrfToken{" + "delegate=" + this.delegate + '}'; - } - } } diff --git a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java index 32b622f33e..e077ebcb2a 100644 --- a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java @@ -60,8 +60,8 @@ public class LoginPageGeneratingWebFilter implements WebFilter { private Mono createBuffer(ServerWebExchange exchange) { MultiValueMap queryParams = exchange.getRequest() .getQueryParams(); - CsrfToken token = exchange.getAttribute(CsrfToken.class.getName()); - return Mono.justOrEmpty(token) + Mono token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty()); + return token .map(LoginPageGeneratingWebFilter::csrfToken) .defaultIfEmpty("") .map(csrfTokenHtmlInput -> { diff --git a/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java b/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java index 96073d896d..7100a7b99e 100644 --- a/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java @@ -57,8 +57,8 @@ public class LogoutPageGeneratingWebFilter implements WebFilter { } private Mono createBuffer(ServerWebExchange exchange) { - CsrfToken token = exchange.getAttribute(CsrfToken.class.getName()); - return Mono.justOrEmpty(token) + Mono token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty()); + return token .map(LogoutPageGeneratingWebFilter::csrfToken) .defaultIfEmpty("") .map(csrfTokenHtmlInput -> { diff --git a/web/src/test/java/org/springframework/security/web/reactive/result/view/CsrfRequestDataValueProcessorTests.java b/web/src/test/java/org/springframework/security/web/reactive/result/view/CsrfRequestDataValueProcessorTests.java index 0091370c09..fe33730317 100644 --- a/web/src/test/java/org/springframework/security/web/reactive/result/view/CsrfRequestDataValueProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/reactive/result/view/CsrfRequestDataValueProcessorTests.java @@ -30,6 +30,7 @@ import java.util.HashMap; import java.util.Map; import static org.assertj.core.api.Assertions.*; +import static org.springframework.security.web.reactive.result.view.CsrfRequestDataValueProcessor.DEFAULT_CSRF_ATTR_NAME; /** * @author Rob Winch @@ -46,7 +47,7 @@ public class CsrfRequestDataValueProcessorTests { @Before public void setup() { this.expected.put(this.token.getParameterName(), this.token.getToken()); - this.exchange.getAttributes().put(CsrfToken.class.getName(), this.token); + this.exchange.getAttributes().put(DEFAULT_CSRF_ATTR_NAME, this.token); } @Test @@ -122,7 +123,7 @@ public class CsrfRequestDataValueProcessorTests { @Test public void createGetExtraHiddenFieldsHasCsrfToken() { CsrfToken token = new DefaultCsrfToken("1", "a", "b"); - this.exchange.getAttributes().put(CsrfToken.class.getName(), token); + this.exchange.getAttributes().put(DEFAULT_CSRF_ATTR_NAME, token); Map expected = new HashMap(); expected.put(token.getParameterName(), token.getToken()); diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java index e033992774..d6a4a4c272 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java @@ -89,8 +89,6 @@ public class CsrfWebFilterTests { this.csrfFilter.setCsrfTokenRepository(this.repository); when(this.repository.loadToken(any())) .thenReturn(Mono.just(this.token)); - when(this.repository.generateToken(any())) - .thenReturn(Mono.just(this.token)); Mono result = this.csrfFilter.filter(this.post, this.chain); @@ -106,8 +104,6 @@ public class CsrfWebFilterTests { this.csrfFilter.setCsrfTokenRepository(this.repository); when(this.repository.loadToken(any())) .thenReturn(Mono.just(this.token)); - when(this.repository.generateToken(any())) - .thenReturn(Mono.just(this.token)); this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") .body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID")); @@ -146,8 +142,6 @@ public class CsrfWebFilterTests { this.csrfFilter.setCsrfTokenRepository(this.repository); when(this.repository.loadToken(any())) .thenReturn(Mono.just(this.token)); - when(this.repository.generateToken(any())) - .thenReturn(Mono.just(this.token)); this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") .header(this.token.getHeaderName(), this.token.getToken()+"INVALID")); diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java index 06832c0fc5..41627e31ba 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java @@ -61,9 +61,10 @@ public class WebSessionServerCsrfTokenRepositoryTests { } @Test - public void generateTokenWhenGetTokenThenAddsToSession() { - Mono result = this.repository.generateToken(this.exchange); - result.block().getToken(); + public void saveTokenWhenDefaultThenAddsToSession() { + Mono result = this.repository.generateToken(this.exchange) + .delayUntil(t-> this.repository.saveToken(this.exchange, t)); + result.block(); WebSession session = this.exchange.getSession().block(); Map attributes = session.getAttributes(); @@ -76,7 +77,6 @@ public class WebSessionServerCsrfTokenRepositoryTests { @Test public void saveTokenWhenNullThenDeletes() { CsrfToken token = this.repository.generateToken(this.exchange).block(); - token.getToken(); Mono result = this.repository.saveToken(this.exchange, null); StepVerifier.create(result) @@ -87,22 +87,6 @@ public class WebSessionServerCsrfTokenRepositoryTests { assertThat(session.getAttributes()).isEmpty(); } - @Test - public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() { - CsrfToken generate = this.repository.generateToken(this.exchange).block(); - generate.getToken(); - - CsrfToken load = this.repository.loadToken(this.exchange).block(); - assertThat(load).isEqualTo(generate); - - this.repository.saveToken(this.exchange, null).block(); - WebSession session = this.exchange.getSession().block(); - assertThat(session.getAttributes()).isEmpty(); - - load = this.repository.loadToken(this.exchange).block(); - assertThat(load).isNull(); - } - @Test public void saveTokenChangeSessionId() { String originalSessionId = this.exchange.getSession().block().getId();