|
|
|
@ -16,26 +16,28 @@ |
|
|
|
|
|
|
|
|
|
|
|
package org.springframework.security.web.server.csrf; |
|
|
|
package org.springframework.security.web.server.csrf; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import java.security.MessageDigest; |
|
|
|
|
|
|
|
import java.util.Arrays; |
|
|
|
|
|
|
|
import java.util.HashSet; |
|
|
|
|
|
|
|
import java.util.Set; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import reactor.core.publisher.Mono; |
|
|
|
|
|
|
|
|
|
|
|
import org.springframework.http.HttpHeaders; |
|
|
|
import org.springframework.http.HttpHeaders; |
|
|
|
import org.springframework.http.HttpMethod; |
|
|
|
import org.springframework.http.HttpMethod; |
|
|
|
import org.springframework.http.HttpStatus; |
|
|
|
import org.springframework.http.HttpStatus; |
|
|
|
import org.springframework.http.MediaType; |
|
|
|
import org.springframework.http.MediaType; |
|
|
|
import org.springframework.http.codec.multipart.FormFieldPart; |
|
|
|
import org.springframework.http.codec.multipart.FormFieldPart; |
|
|
|
import org.springframework.http.server.reactive.ServerHttpRequest; |
|
|
|
import org.springframework.http.server.reactive.ServerHttpRequest; |
|
|
|
|
|
|
|
import org.springframework.security.crypto.codec.Utf8; |
|
|
|
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; |
|
|
|
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; |
|
|
|
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; |
|
|
|
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; |
|
|
|
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; |
|
|
|
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; |
|
|
|
|
|
|
|
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult; |
|
|
|
import org.springframework.util.Assert; |
|
|
|
import org.springframework.util.Assert; |
|
|
|
import org.springframework.web.server.ServerWebExchange; |
|
|
|
import org.springframework.web.server.ServerWebExchange; |
|
|
|
import org.springframework.web.server.WebFilter; |
|
|
|
import org.springframework.web.server.WebFilter; |
|
|
|
import org.springframework.web.server.WebFilterChain; |
|
|
|
import org.springframework.web.server.WebFilterChain; |
|
|
|
import reactor.core.publisher.Mono; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import java.util.Arrays; |
|
|
|
|
|
|
|
import java.util.HashSet; |
|
|
|
|
|
|
|
import java.util.Set; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import static java.lang.Boolean.TRUE; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
/** |
|
|
|
* <p> |
|
|
|
* <p> |
|
|
|
@ -64,13 +66,14 @@ import static java.lang.Boolean.TRUE; |
|
|
|
* @since 5.0 |
|
|
|
* @since 5.0 |
|
|
|
*/ |
|
|
|
*/ |
|
|
|
public class CsrfWebFilter implements WebFilter { |
|
|
|
public class CsrfWebFilter implements WebFilter { |
|
|
|
|
|
|
|
|
|
|
|
public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher(); |
|
|
|
public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher(); |
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
/** |
|
|
|
* The attribute name to use when marking a given request as one that should not be filtered. |
|
|
|
* The attribute name to use when marking a given request as one that should not be |
|
|
|
|
|
|
|
* filtered. |
|
|
|
* |
|
|
|
* |
|
|
|
* To use, set the attribute on your {@link ServerWebExchange}: |
|
|
|
* To use, set the attribute on your {@link ServerWebExchange}: <pre> |
|
|
|
* <pre> |
|
|
|
|
|
|
|
* CsrfWebFilter.skipExchange(exchange); |
|
|
|
* CsrfWebFilter.skipExchange(exchange); |
|
|
|
* </pre> |
|
|
|
* </pre> |
|
|
|
*/ |
|
|
|
*/ |
|
|
|
@ -80,32 +83,31 @@ public class CsrfWebFilter implements WebFilter { |
|
|
|
|
|
|
|
|
|
|
|
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository(); |
|
|
|
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository(); |
|
|
|
|
|
|
|
|
|
|
|
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN); |
|
|
|
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler( |
|
|
|
|
|
|
|
HttpStatus.FORBIDDEN); |
|
|
|
|
|
|
|
|
|
|
|
private boolean isTokenFromMultipartDataEnabled; |
|
|
|
private boolean isTokenFromMultipartDataEnabled; |
|
|
|
|
|
|
|
|
|
|
|
public void setAccessDeniedHandler( |
|
|
|
public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) { |
|
|
|
ServerAccessDeniedHandler accessDeniedHandler) { |
|
|
|
|
|
|
|
Assert.notNull(accessDeniedHandler, "accessDeniedHandler"); |
|
|
|
Assert.notNull(accessDeniedHandler, "accessDeniedHandler"); |
|
|
|
this.accessDeniedHandler = accessDeniedHandler; |
|
|
|
this.accessDeniedHandler = accessDeniedHandler; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public void setCsrfTokenRepository( |
|
|
|
public void setCsrfTokenRepository(ServerCsrfTokenRepository csrfTokenRepository) { |
|
|
|
ServerCsrfTokenRepository csrfTokenRepository) { |
|
|
|
|
|
|
|
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); |
|
|
|
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); |
|
|
|
this.csrfTokenRepository = csrfTokenRepository; |
|
|
|
this.csrfTokenRepository = csrfTokenRepository; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public void setRequireCsrfProtectionMatcher( |
|
|
|
public void setRequireCsrfProtectionMatcher(ServerWebExchangeMatcher requireCsrfProtectionMatcher) { |
|
|
|
ServerWebExchangeMatcher requireCsrfProtectionMatcher) { |
|
|
|
|
|
|
|
Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null"); |
|
|
|
Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null"); |
|
|
|
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher; |
|
|
|
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
/** |
|
|
|
* Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart |
|
|
|
* Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token |
|
|
|
* data requests. |
|
|
|
* from the body of multipart data requests. |
|
|
|
* @param tokenFromMultipartDataEnabled true if should read from multipart form body, else false. Default is false |
|
|
|
* @param tokenFromMultipartDataEnabled true if should read from multipart form body, |
|
|
|
|
|
|
|
* else false. Default is false |
|
|
|
*/ |
|
|
|
*/ |
|
|
|
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) { |
|
|
|
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) { |
|
|
|
this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled; |
|
|
|
this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled; |
|
|
|
@ -113,38 +115,33 @@ public class CsrfWebFilter implements WebFilter { |
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
@Override |
|
|
|
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { |
|
|
|
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { |
|
|
|
if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) { |
|
|
|
if (Boolean.TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) { |
|
|
|
return chain.filter(exchange).then(Mono.empty()); |
|
|
|
return chain.filter(exchange).then(Mono.empty()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return this.requireCsrfProtectionMatcher.matches(exchange).filter(MatchResult::isMatch) |
|
|
|
return this.requireCsrfProtectionMatcher.matches(exchange) |
|
|
|
.filter((matchResult) -> !exchange.getAttributes().containsKey(CsrfToken.class.getName())) |
|
|
|
.filter( matchResult -> matchResult.isMatch()) |
|
|
|
.flatMap((m) -> validateToken(exchange)).flatMap((m) -> continueFilterChain(exchange, chain)) |
|
|
|
.filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName())) |
|
|
|
.switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty())) |
|
|
|
.flatMap(m -> validateToken(exchange)) |
|
|
|
.onErrorResume(CsrfException.class, (ex) -> this.accessDeniedHandler.handle(exchange, ex)); |
|
|
|
.flatMap(m -> continueFilterChain(exchange, chain)) |
|
|
|
|
|
|
|
.switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty())) |
|
|
|
|
|
|
|
.onErrorResume(CsrfException.class, e -> this.accessDeniedHandler |
|
|
|
|
|
|
|
.handle(exchange, e)); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public static void skipExchange(ServerWebExchange exchange) { |
|
|
|
public static void skipExchange(ServerWebExchange exchange) { |
|
|
|
exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE); |
|
|
|
exchange.getAttributes().put(SHOULD_NOT_FILTER, Boolean.TRUE); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private Mono<Void> validateToken(ServerWebExchange exchange) { |
|
|
|
private Mono<Void> validateToken(ServerWebExchange exchange) { |
|
|
|
return this.csrfTokenRepository.loadToken(exchange) |
|
|
|
return this.csrfTokenRepository.loadToken(exchange) |
|
|
|
.switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found")))) |
|
|
|
.switchIfEmpty( |
|
|
|
.filterWhen(expected -> containsValidCsrfToken(exchange, expected)) |
|
|
|
Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found")))) |
|
|
|
.switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token")))) |
|
|
|
.filterWhen((expected) -> containsValidCsrfToken(exchange, expected)) |
|
|
|
.then(); |
|
|
|
.switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token")))).then(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) { |
|
|
|
private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) { |
|
|
|
return exchange.getFormData() |
|
|
|
return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) |
|
|
|
.flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) |
|
|
|
.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) |
|
|
|
.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) |
|
|
|
.switchIfEmpty(tokenFromMultipartData(exchange, expected)) |
|
|
|
.switchIfEmpty(tokenFromMultipartData(exchange, expected)) |
|
|
|
.map((actual) -> equalsConstantTime(actual, expected.getToken())); |
|
|
|
.map(actual -> actual.equals(expected.getToken())); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) { |
|
|
|
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) { |
|
|
|
@ -157,14 +154,12 @@ public class CsrfWebFilter implements WebFilter { |
|
|
|
if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) { |
|
|
|
if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) { |
|
|
|
return Mono.empty(); |
|
|
|
return Mono.empty(); |
|
|
|
} |
|
|
|
} |
|
|
|
return exchange.getMultipartData() |
|
|
|
return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class) |
|
|
|
.map(d -> d.getFirst(expected.getParameterName())) |
|
|
|
.map(FormFieldPart::value); |
|
|
|
.cast(FormFieldPart.class) |
|
|
|
|
|
|
|
.map(FormFieldPart::value); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { |
|
|
|
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { |
|
|
|
return Mono.defer(() ->{ |
|
|
|
return Mono.defer(() -> { |
|
|
|
Mono<CsrfToken> csrfToken = csrfToken(exchange); |
|
|
|
Mono<CsrfToken> csrfToken = csrfToken(exchange); |
|
|
|
exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken); |
|
|
|
exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken); |
|
|
|
return chain.filter(exchange); |
|
|
|
return chain.filter(exchange); |
|
|
|
@ -172,26 +167,44 @@ public class CsrfWebFilter implements WebFilter { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) { |
|
|
|
private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) { |
|
|
|
return this.csrfTokenRepository.loadToken(exchange) |
|
|
|
return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange)); |
|
|
|
.switchIfEmpty(generateToken(exchange)); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
|
|
|
* Constant time comparison to prevent against timing attacks. |
|
|
|
|
|
|
|
* @param expected |
|
|
|
|
|
|
|
* @param actual |
|
|
|
|
|
|
|
* @return |
|
|
|
|
|
|
|
*/ |
|
|
|
|
|
|
|
private static boolean equalsConstantTime(String expected, String actual) { |
|
|
|
|
|
|
|
byte[] expectedBytes = bytesUtf8(expected); |
|
|
|
|
|
|
|
byte[] actualBytes = bytesUtf8(actual); |
|
|
|
|
|
|
|
return MessageDigest.isEqual(expectedBytes, actualBytes); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static byte[] bytesUtf8(String s) { |
|
|
|
|
|
|
|
// need to check if Utf8.encode() runs in constant time (probably not).
|
|
|
|
|
|
|
|
// This may leak length of string.
|
|
|
|
|
|
|
|
return (s != null) ? Utf8.encode(s) : null; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) { |
|
|
|
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) { |
|
|
|
return this.csrfTokenRepository.generateToken(exchange) |
|
|
|
return this.csrfTokenRepository.generateToken(exchange) |
|
|
|
.delayUntil(token -> this.csrfTokenRepository.saveToken(exchange, token)); |
|
|
|
.delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher { |
|
|
|
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher { |
|
|
|
|
|
|
|
|
|
|
|
private static final Set<HttpMethod> ALLOWED_METHODS = new HashSet<>( |
|
|
|
private static final Set<HttpMethod> ALLOWED_METHODS = new HashSet<>( |
|
|
|
Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS)); |
|
|
|
Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS)); |
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
@Override |
|
|
|
public Mono<MatchResult> matches(ServerWebExchange exchange) { |
|
|
|
public Mono<MatchResult> matches(ServerWebExchange exchange) { |
|
|
|
return Mono.just(exchange.getRequest()) |
|
|
|
return Mono.just(exchange.getRequest()).flatMap((r) -> Mono.justOrEmpty(r.getMethod())) |
|
|
|
.flatMap(r -> Mono.justOrEmpty(r.getMethod())) |
|
|
|
.filter(ALLOWED_METHODS::contains).flatMap((m) -> MatchResult.notMatch()) |
|
|
|
.filter(m -> ALLOWED_METHODS.contains(m)) |
|
|
|
.switchIfEmpty(MatchResult.match()); |
|
|
|
.flatMap(m -> MatchResult.notMatch()) |
|
|
|
|
|
|
|
.switchIfEmpty(MatchResult.match()); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|