@ -16,6 +16,7 @@
@@ -16,6 +16,7 @@
package org.springframework.security.web.server.csrf ;
import java.security.MessageDigest ;
import java.util.Arrays ;
import java.util.HashSet ;
import java.util.Set ;
@ -28,6 +29,7 @@ import org.springframework.http.HttpStatus;
@@ -28,6 +29,7 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType ;
import org.springframework.http.codec.multipart.FormFieldPart ;
import org.springframework.http.server.reactive.ServerHttpRequest ;
import org.springframework.security.crypto.codec.Utf8 ;
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler ;
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler ;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher ;
@ -139,7 +141,7 @@ public class CsrfWebFilter implements WebFilter {
@@ -139,7 +141,7 @@ public class CsrfWebFilter implements WebFilter {
return exchange . getFormData ( ) . flatMap ( ( data ) - > Mono . justOrEmpty ( data . getFirst ( expected . getParameterName ( ) ) ) )
. switchIfEmpty ( Mono . justOrEmpty ( exchange . getRequest ( ) . getHeaders ( ) . getFirst ( expected . getHeaderName ( ) ) ) )
. switchIfEmpty ( tokenFromMultipartData ( exchange , expected ) )
. map ( ( actual ) - > actual . equals ( expected . getToken ( ) ) ) ;
. map ( ( actual ) - > equalsConstantTime ( actual , expected . getToken ( ) ) ) ;
}
private Mono < String > tokenFromMultipartData ( ServerWebExchange exchange , CsrfToken expected ) {
@ -168,6 +170,24 @@ public class CsrfWebFilter implements WebFilter {
@@ -168,6 +170,24 @@ public class CsrfWebFilter implements WebFilter {
return this . csrfTokenRepository . loadToken ( exchange ) . switchIfEmpty ( generateToken ( exchange ) ) ;
}
/ * *
* Constant time comparison to prevent against timing attacks .
* @param expected
* @param actual
* @return
* /
private static boolean equalsConstantTime ( String expected , String actual ) {
byte [ ] expectedBytes = bytesUtf8 ( expected ) ;
byte [ ] actualBytes = bytesUtf8 ( actual ) ;
return MessageDigest . isEqual ( expectedBytes , actualBytes ) ;
}
private static byte [ ] bytesUtf8 ( String s ) {
// need to check if Utf8.encode() runs in constant time (probably not).
// This may leak length of string.
return ( s ! = null ) ? Utf8 . encode ( s ) : null ;
}
private Mono < CsrfToken > generateToken ( ServerWebExchange exchange ) {
return this . csrfTokenRepository . generateToken ( exchange )
. delayUntil ( ( token ) - > this . csrfTokenRepository . saveToken ( exchange , token ) ) ;