@ -16,26 +16,28 @@
@@ -16,26 +16,28 @@
package org.springframework.security.web.server.csrf ;
import java.security.MessageDigest ;
import java.util.Arrays ;
import java.util.HashSet ;
import java.util.Set ;
import reactor.core.publisher.Mono ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpMethod ;
import org.springframework.http.HttpStatus ;
import org.springframework.http.MediaType ;
import org.springframework.http.codec.multipart.FormFieldPart ;
import org.springframework.http.server.reactive.ServerHttpRequest ;
import org.springframework.security.crypto.codec.Utf8 ;
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler ;
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler ;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher ;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult ;
import org.springframework.util.Assert ;
import org.springframework.web.server.ServerWebExchange ;
import org.springframework.web.server.WebFilter ;
import org.springframework.web.server.WebFilterChain ;
import reactor.core.publisher.Mono ;
import java.util.Arrays ;
import java.util.HashSet ;
import java.util.Set ;
import static java.lang.Boolean.TRUE ;
/ * *
* < p >
@ -64,13 +66,14 @@ import static java.lang.Boolean.TRUE;
@@ -64,13 +66,14 @@ import static java.lang.Boolean.TRUE;
* @since 5 . 0
* /
public class CsrfWebFilter implements WebFilter {
public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher ( ) ;
/ * *
* The attribute name to use when marking a given request as one that should not be filtered .
* The attribute name to use when marking a given request as one that should not be
* filtered .
*
* To use , set the attribute on your { @link ServerWebExchange } :
* < pre >
* To use , set the attribute on your { @link ServerWebExchange } : < pre >
* CsrfWebFilter . skipExchange ( exchange ) ;
* < / pre >
* /
@ -80,32 +83,31 @@ public class CsrfWebFilter implements WebFilter {
@@ -80,32 +83,31 @@ public class CsrfWebFilter implements WebFilter {
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository ( ) ;
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler ( HttpStatus . FORBIDDEN ) ;
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler (
HttpStatus . FORBIDDEN ) ;
private boolean isTokenFromMultipartDataEnabled ;
public void setAccessDeniedHandler (
ServerAccessDeniedHandler accessDeniedHandler ) {
public void setAccessDeniedHandler ( ServerAccessDeniedHandler accessDeniedHandler ) {
Assert . notNull ( accessDeniedHandler , "accessDeniedHandler" ) ;
this . accessDeniedHandler = accessDeniedHandler ;
}
public void setCsrfTokenRepository (
ServerCsrfTokenRepository csrfTokenRepository ) {
public void setCsrfTokenRepository ( ServerCsrfTokenRepository csrfTokenRepository ) {
Assert . notNull ( csrfTokenRepository , "csrfTokenRepository cannot be null" ) ;
this . csrfTokenRepository = csrfTokenRepository ;
}
public void setRequireCsrfProtectionMatcher (
ServerWebExchangeMatcher requireCsrfProtectionMatcher ) {
public void setRequireCsrfProtectionMatcher ( ServerWebExchangeMatcher requireCsrfProtectionMatcher ) {
Assert . notNull ( requireCsrfProtectionMatcher , "requireCsrfProtectionMatcher cannot be null" ) ;
this . requireCsrfProtectionMatcher = requireCsrfProtectionMatcher ;
}
/ * *
* Specifies if the { @code CsrfWebFilter } should try to resolve the actual CSRF token from the body of multipart
* data requests .
* @param tokenFromMultipartDataEnabled true if should read from multipart form body , else false . Default is false
* Specifies if the { @code CsrfWebFilter } should try to resolve the actual CSRF token
* from the body of multipart data requests .
* @param tokenFromMultipartDataEnabled true if should read from multipart form body ,
* else false . Default is false
* /
public void setTokenFromMultipartDataEnabled ( boolean tokenFromMultipartDataEnabled ) {
this . isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled ;
@ -113,38 +115,33 @@ public class CsrfWebFilter implements WebFilter {
@@ -113,38 +115,33 @@ public class CsrfWebFilter implements WebFilter {
@Override
public Mono < Void > filter ( ServerWebExchange exchange , WebFilterChain chain ) {
if ( TRUE . equals ( exchange . getAttribute ( SHOULD_NOT_FILTER ) ) ) {
if ( Boolean . TRUE . equals ( exchange . getAttribute ( SHOULD_NOT_FILTER ) ) ) {
return chain . filter ( exchange ) . then ( Mono . empty ( ) ) ;
}
return this . requireCsrfProtectionMatcher . matches ( exchange )
. filter ( matchResult - > matchResult . isMatch ( ) )
. filter ( matchResult - > ! exchange . getAttributes ( ) . containsKey ( CsrfToken . class . getName ( ) ) )
. flatMap ( m - > validateToken ( exchange ) )
. flatMap ( m - > continueFilterChain ( exchange , chain ) )
. switchIfEmpty ( continueFilterChain ( exchange , chain ) . then ( Mono . empty ( ) ) )
. onErrorResume ( CsrfException . class , e - > this . accessDeniedHandler
. handle ( exchange , e ) ) ;
return this . requireCsrfProtectionMatcher . matches ( exchange ) . filter ( MatchResult : : isMatch )
. filter ( ( matchResult ) - > ! exchange . getAttributes ( ) . containsKey ( CsrfToken . class . getName ( ) ) )
. flatMap ( ( m ) - > validateToken ( exchange ) ) . flatMap ( ( m ) - > continueFilterChain ( exchange , chain ) )
. switchIfEmpty ( continueFilterChain ( exchange , chain ) . then ( Mono . empty ( ) ) )
. onErrorResume ( CsrfException . class , ( ex ) - > this . accessDeniedHandler . handle ( exchange , ex ) ) ;
}
public static void skipExchange ( ServerWebExchange exchange ) {
exchange . getAttributes ( ) . put ( SHOULD_NOT_FILTER , TRUE ) ;
exchange . getAttributes ( ) . put ( SHOULD_NOT_FILTER , Boolean . TRUE ) ;
}
private Mono < Void > validateToken ( ServerWebExchange exchange ) {
return this . csrfTokenRepository . loadToken ( exchange )
. switchIfEmpty ( Mono . defer ( ( ) - > Mono . error ( new CsrfException ( "An expected CSRF token cannot be found" ) ) ) )
. filterWhen ( expected - > containsValidCsrfToken ( exchange , expected ) )
. switchIfEmpty ( Mono . defer ( ( ) - > Mono . error ( new CsrfException ( "Invalid CSRF Token" ) ) ) )
. then ( ) ;
. switchIfEmpty (
Mono . defer ( ( ) - > Mono . error ( new CsrfException ( "An expected CSRF token cannot be found" ) ) ) )
. filterWhen ( ( expected ) - > containsValidCsrfToken ( exchange , expected ) )
. switchIfEmpty ( Mono . defer ( ( ) - > Mono . error ( new CsrfException ( "Invalid CSRF Token" ) ) ) ) . then ( ) ;
}
private Mono < Boolean > containsValidCsrfToken ( ServerWebExchange exchange , CsrfToken expected ) {
return exchange . getFormData ( )
. flatMap ( data - > Mono . justOrEmpty ( data . getFirst ( expected . getParameterName ( ) ) ) )
. switchIfEmpty ( Mono . justOrEmpty ( exchange . getRequest ( ) . getHeaders ( ) . getFirst ( expected . getHeaderName ( ) ) ) )
. switchIfEmpty ( tokenFromMultipartData ( exchange , expected ) )
. map ( actual - > actual . equals ( expected . getToken ( ) ) ) ;
return exchange . getFormData ( ) . flatMap ( ( data ) - > Mono . justOrEmpty ( data . getFirst ( expected . getParameterName ( ) ) ) )
. switchIfEmpty ( Mono . justOrEmpty ( exchange . getRequest ( ) . getHeaders ( ) . getFirst ( expected . getHeaderName ( ) ) ) )
. switchIfEmpty ( tokenFromMultipartData ( exchange , expected ) )
. map ( ( actual ) - > equalsConstantTime ( actual , expected . getToken ( ) ) ) ;
}
private Mono < String > tokenFromMultipartData ( ServerWebExchange exchange , CsrfToken expected ) {
@ -157,14 +154,12 @@ public class CsrfWebFilter implements WebFilter {
@@ -157,14 +154,12 @@ public class CsrfWebFilter implements WebFilter {
if ( ! contentType . includes ( MediaType . MULTIPART_FORM_DATA ) ) {
return Mono . empty ( ) ;
}
return exchange . getMultipartData ( )
. map ( d - > d . getFirst ( expected . getParameterName ( ) ) )
. cast ( FormFieldPart . class )
. map ( FormFieldPart : : value ) ;
return exchange . getMultipartData ( ) . map ( ( d ) - > d . getFirst ( expected . getParameterName ( ) ) ) . cast ( FormFieldPart . class )
. map ( FormFieldPart : : value ) ;
}
private Mono < Void > continueFilterChain ( ServerWebExchange exchange , WebFilterChain chain ) {
return Mono . defer ( ( ) - > {
return Mono . defer ( ( ) - > {
Mono < CsrfToken > csrfToken = csrfToken ( exchange ) ;
exchange . getAttributes ( ) . put ( CsrfToken . class . getName ( ) , csrfToken ) ;
return chain . filter ( exchange ) ;
@ -172,26 +167,44 @@ public class CsrfWebFilter implements WebFilter {
@@ -172,26 +167,44 @@ public class CsrfWebFilter implements WebFilter {
}
private Mono < CsrfToken > csrfToken ( ServerWebExchange exchange ) {
return this . csrfTokenRepository . loadToken ( exchange )
. switchIfEmpty ( generateToken ( exchange ) ) ;
return this . csrfTokenRepository . loadToken ( exchange ) . switchIfEmpty ( generateToken ( exchange ) ) ;
}
/ * *
* Constant time comparison to prevent against timing attacks .
* @param expected
* @param actual
* @return
* /
private static boolean equalsConstantTime ( String expected , String actual ) {
byte [ ] expectedBytes = bytesUtf8 ( expected ) ;
byte [ ] actualBytes = bytesUtf8 ( actual ) ;
return MessageDigest . isEqual ( expectedBytes , actualBytes ) ;
}
private static byte [ ] bytesUtf8 ( String s ) {
// need to check if Utf8.encode() runs in constant time (probably not).
// This may leak length of string.
return ( s ! = null ) ? Utf8 . encode ( s ) : null ;
}
private Mono < CsrfToken > generateToken ( ServerWebExchange exchange ) {
return this . csrfTokenRepository . generateToken ( exchange )
. delayUntil ( token - > this . csrfTokenRepository . saveToken ( exchange , token ) ) ;
. delayUntil ( ( token ) - > this . csrfTokenRepository . saveToken ( exchange , token ) ) ;
}
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {
private static final Set < HttpMethod > ALLOWED_METHODS = new HashSet < > (
Arrays . asList ( HttpMethod . GET , HttpMethod . HEAD , HttpMethod . TRACE , HttpMethod . OPTIONS ) ) ;
Arrays . asList ( HttpMethod . GET , HttpMethod . HEAD , HttpMethod . TRACE , HttpMethod . OPTIONS ) ) ;
@Override
public Mono < MatchResult > matches ( ServerWebExchange exchange ) {
return Mono . just ( exchange . getRequest ( ) )
. flatMap ( r - > Mono . justOrEmpty ( r . getMethod ( ) ) )
. filter ( m - > ALLOWED_METHODS . contains ( m ) )
. flatMap ( m - > MatchResult . notMatch ( ) )
. switchIfEmpty ( MatchResult . match ( ) ) ;
return Mono . just ( exchange . getRequest ( ) ) . flatMap ( ( r ) - > Mono . justOrEmpty ( r . getMethod ( ) ) )
. filter ( ALLOWED_METHODS : : contains ) . flatMap ( ( m ) - > MatchResult . notMatch ( ) )
. switchIfEmpty ( MatchResult . match ( ) ) ;
}
}
}