From e6d6b397677598b62e5270309d3f0a0711d8e692 Mon Sep 17 00:00:00 2001
From: Rob Winch
* Applies
@@ -58,6 +61,7 @@ import static java.lang.Boolean.TRUE;
* @since 3.2
*/
public final class CsrfFilter extends OncePerRequestFilter {
+
/**
* The default {@link RequestMatcher} that indicates if CSRF protection is required or
* not. The default is to ignore GET, HEAD, TRACE, OPTIONS and process all other
@@ -66,18 +70,21 @@ public final class CsrfFilter extends OncePerRequestFilter {
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();
/**
- * The attribute name to use when marking a given request as one that should not be filtered.
+ * The attribute name to use when marking a given request as one that should not be
+ * filtered.
*
- * To use, set the attribute on your {@link HttpServletRequest}:
- *
+ * To use, set the attribute on your {@link HttpServletRequest}:
* CsrfFilter.skipRequest(request);
*
*/
private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter.class.getName();
private final Log logger = LogFactory.getLog(getClass());
+
private final CsrfTokenRepository tokenRepository;
+
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
+
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
@@ -87,62 +94,46 @@ public final class CsrfFilter extends OncePerRequestFilter {
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
- return TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
+ return Boolean.TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
}
- /*
- * (non-Javadoc)
- *
- * @see
- * org.springframework.web.filter.OncePerRequestFilter#doFilterInternal(javax.servlet
- * .http.HttpServletRequest, javax.servlet.http.HttpServletResponse,
- * javax.servlet.FilterChain)
- */
@Override
- protected void doFilterInternal(HttpServletRequest request,
- HttpServletResponse response, FilterChain filterChain)
- throws ServletException, IOException {
+ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
+ throws ServletException, IOException {
request.setAttribute(HttpServletResponse.class.getName(), response);
-
CsrfToken csrfToken = this.tokenRepository.loadToken(request);
- final boolean missingToken = csrfToken == null;
+ boolean missingToken = (csrfToken == null);
if (missingToken) {
csrfToken = this.tokenRepository.generateToken(request);
this.tokenRepository.saveToken(csrfToken, request, response);
}
request.setAttribute(CsrfToken.class.getName(), csrfToken);
request.setAttribute(csrfToken.getParameterName(), csrfToken);
-
if (!this.requireCsrfProtectionMatcher.matches(request)) {
+ if (this.logger.isTraceEnabled()) {
+ this.logger.trace("Did not protect against CSRF since request did not match "
+ + this.requireCsrfProtectionMatcher);
+ }
filterChain.doFilter(request, response);
return;
}
-
String actualToken = request.getHeader(csrfToken.getHeaderName());
if (actualToken == null) {
actualToken = request.getParameter(csrfToken.getParameterName());
}
- if (!csrfToken.getToken().equals(actualToken)) {
- if (this.logger.isDebugEnabled()) {
- this.logger.debug("Invalid CSRF token found for "
- + UrlUtils.buildFullRequestUrl(request));
- }
- if (missingToken) {
- this.accessDeniedHandler.handle(request, response,
- new MissingCsrfTokenException(actualToken));
- }
- else {
- this.accessDeniedHandler.handle(request, response,
- new InvalidCsrfTokenException(csrfToken, actualToken));
- }
+ if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
+ this.logger.debug(
+ LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)));
+ AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken)
+ : new MissingCsrfTokenException(actualToken);
+ this.accessDeniedHandler.handle(request, response, exception);
return;
}
-
filterChain.doFilter(request, response);
}
public static void skipRequest(HttpServletRequest request) {
- request.setAttribute(SHOULD_NOT_FILTER, TRUE);
+ request.setAttribute(SHOULD_NOT_FILTER, Boolean.TRUE);
}
/**
@@ -154,14 +145,11 @@ public final class CsrfFilter extends OncePerRequestFilter {
* The default is to apply CSRF protection for any HTTP method other than GET, HEAD,
* TRACE, OPTIONS.
*
* The default is to use AccessDeniedHandlerImpl with no arguments. *
- * * @param accessDeniedHandler the {@link AccessDeniedHandler} to use */ public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) { @@ -180,20 +167,38 @@ public final class CsrfFilter extends OncePerRequestFilter { this.accessDeniedHandler = accessDeniedHandler; } + /** + * Constant time comparison to prevent against timing attacks. + * @param expected + * @param actual + * @return + */ + private static boolean equalsConstantTime(String expected, String actual) { + byte[] expectedBytes = bytesUtf8(expected); + byte[] actualBytes = bytesUtf8(actual); + return MessageDigest.isEqual(expectedBytes, actualBytes); + } + + private static byte[] bytesUtf8(String s) { + // need to check if Utf8.encode() runs in constant time (probably not). + // This may leak length of string. + return (s != null) ? Utf8.encode(s) : null; + } + private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { - private final HashSet@@ -64,13 +66,14 @@ import static java.lang.Boolean.TRUE; * @since 5.0 */ public class CsrfWebFilter implements WebFilter { + public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher(); /** - * The attribute name to use when marking a given request as one that should not be filtered. + * The attribute name to use when marking a given request as one that should not be + * filtered. * - * To use, set the attribute on your {@link ServerWebExchange}: - *
+ * To use, set the attribute on your {@link ServerWebExchange}:
* CsrfWebFilter.skipExchange(exchange);
*
*/
@@ -80,32 +83,31 @@ public class CsrfWebFilter implements WebFilter {
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
- private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN);
+ private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(
+ HttpStatus.FORBIDDEN);
private boolean isTokenFromMultipartDataEnabled;
- public void setAccessDeniedHandler(
- ServerAccessDeniedHandler accessDeniedHandler) {
+ public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) {
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
this.accessDeniedHandler = accessDeniedHandler;
}
- public void setCsrfTokenRepository(
- ServerCsrfTokenRepository csrfTokenRepository) {
+ public void setCsrfTokenRepository(ServerCsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
this.csrfTokenRepository = csrfTokenRepository;
}
- public void setRequireCsrfProtectionMatcher(
- ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
+ public void setRequireCsrfProtectionMatcher(ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null");
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
}
/**
- * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart
- * data requests.
- * @param tokenFromMultipartDataEnabled true if should read from multipart form body, else false. Default is false
+ * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token
+ * from the body of multipart data requests.
+ * @param tokenFromMultipartDataEnabled true if should read from multipart form body,
+ * else false. Default is false
*/
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
@@ -113,38 +115,33 @@ public class CsrfWebFilter implements WebFilter {
@Override
public Mono filter(ServerWebExchange exchange, WebFilterChain chain) {
- if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
+ if (Boolean.TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
return chain.filter(exchange).then(Mono.empty());
}
-
- return this.requireCsrfProtectionMatcher.matches(exchange)
- .filter( matchResult -> matchResult.isMatch())
- .filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
- .flatMap(m -> validateToken(exchange))
- .flatMap(m -> continueFilterChain(exchange, chain))
- .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty()))
- .onErrorResume(CsrfException.class, e -> this.accessDeniedHandler
- .handle(exchange, e));
+ return this.requireCsrfProtectionMatcher.matches(exchange).filter(MatchResult::isMatch)
+ .filter((matchResult) -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
+ .flatMap((m) -> validateToken(exchange)).flatMap((m) -> continueFilterChain(exchange, chain))
+ .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty()))
+ .onErrorResume(CsrfException.class, (ex) -> this.accessDeniedHandler.handle(exchange, ex));
}
public static void skipExchange(ServerWebExchange exchange) {
- exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE);
+ exchange.getAttributes().put(SHOULD_NOT_FILTER, Boolean.TRUE);
}
private Mono validateToken(ServerWebExchange exchange) {
return this.csrfTokenRepository.loadToken(exchange)
- .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found"))))
- .filterWhen(expected -> containsValidCsrfToken(exchange, expected))
- .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token"))))
- .then();
+ .switchIfEmpty(
+ Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found"))))
+ .filterWhen((expected) -> containsValidCsrfToken(exchange, expected))
+ .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token")))).then();
}
private Mono containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
- return exchange.getFormData()
- .flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
- .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
- .switchIfEmpty(tokenFromMultipartData(exchange, expected))
- .map(actual -> actual.equals(expected.getToken()));
+ return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
+ .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
+ .switchIfEmpty(tokenFromMultipartData(exchange, expected))
+ .map((actual) -> equalsConstantTime(actual, expected.getToken()));
}
private Mono tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
@@ -157,14 +154,12 @@ public class CsrfWebFilter implements WebFilter {
if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) {
return Mono.empty();
}
- return exchange.getMultipartData()
- .map(d -> d.getFirst(expected.getParameterName()))
- .cast(FormFieldPart.class)
- .map(FormFieldPart::value);
+ return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class)
+ .map(FormFieldPart::value);
}
private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
- return Mono.defer(() ->{
+ return Mono.defer(() -> {
Mono csrfToken = csrfToken(exchange);
exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
return chain.filter(exchange);
@@ -172,26 +167,44 @@ public class CsrfWebFilter implements WebFilter {
}
private Mono csrfToken(ServerWebExchange exchange) {
- return this.csrfTokenRepository.loadToken(exchange)
- .switchIfEmpty(generateToken(exchange));
+ return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange));
+ }
+
+ /**
+ * Constant time comparison to prevent against timing attacks.
+ * @param expected
+ * @param actual
+ * @return
+ */
+ private static boolean equalsConstantTime(String expected, String actual) {
+ byte[] expectedBytes = bytesUtf8(expected);
+ byte[] actualBytes = bytesUtf8(actual);
+ return MessageDigest.isEqual(expectedBytes, actualBytes);
+ }
+
+ private static byte[] bytesUtf8(String s) {
+ // need to check if Utf8.encode() runs in constant time (probably not).
+ // This may leak length of string.
+ return (s != null) ? Utf8.encode(s) : null;
}
private Mono generateToken(ServerWebExchange exchange) {
return this.csrfTokenRepository.generateToken(exchange)
- .delayUntil(token -> this.csrfTokenRepository.saveToken(exchange, token));
+ .delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token));
}
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {
+
private static final Set ALLOWED_METHODS = new HashSet<>(
- Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS));
+ Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS));
@Override
public Mono matches(ServerWebExchange exchange) {
- return Mono.just(exchange.getRequest())
- .flatMap(r -> Mono.justOrEmpty(r.getMethod()))
- .filter(m -> ALLOWED_METHODS.contains(m))
- .flatMap(m -> MatchResult.notMatch())
- .switchIfEmpty(MatchResult.match());
+ return Mono.just(exchange.getRequest()).flatMap((r) -> Mono.justOrEmpty(r.getMethod()))
+ .filter(ALLOWED_METHODS::contains).flatMap((m) -> MatchResult.notMatch())
+ .switchIfEmpty(MatchResult.match());
}
+
}
+
}