@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2022 the original author or authors .
* Copyright 2002 - 2023 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -127,11 +127,12 @@ public class CsrfFilterTests {
@@ -127,11 +127,12 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedNoTokenPresent ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( InvalidCsrfTokenException . class ) ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -139,12 +140,13 @@ public class CsrfFilterTests {
@@ -139,12 +140,13 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedIncorrectTokenPresent ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) + " INVALID" ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( InvalidCsrfTokenException . class ) ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -152,12 +154,13 @@ public class CsrfFilterTests {
@@ -152,12 +154,13 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedIncorrectTokenPresentHeader ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
this . request . addHeader ( this . token . getHeaderName ( ) , this . token . getToken ( ) + " INVALID" ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( InvalidCsrfTokenException . class ) ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -166,8 +169,8 @@ public class CsrfFilterTests {
@@ -166,8 +169,8 @@ public class CsrfFilterTests {
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter ( )
throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler ( ) ;
handler . handle ( this . request , this . response , ( ) - > this . token ) ;
CsrfToken csrfToken = ( CsrfToken ) this . request . getAttribute ( CsrfToken . class . getName ( ) ) ;
@ -176,6 +179,7 @@ public class CsrfFilterTests {
@@ -176,6 +179,7 @@ public class CsrfFilterTests {
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( InvalidCsrfTokenException . class ) ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -183,11 +187,12 @@ public class CsrfFilterTests {
@@ -183,11 +187,12 @@ public class CsrfFilterTests {
@Test
public void doFilterNotCsrfRequestExistingToken ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
}
@ -195,11 +200,12 @@ public class CsrfFilterTests {
@@ -195,11 +200,12 @@ public class CsrfFilterTests {
@Test
public void doFilterNotCsrfRequestGenerateToken ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , true ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
}
@ -207,8 +213,8 @@ public class CsrfFilterTests {
@@ -207,8 +213,8 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestExistingTokenHeader ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler ( ) ;
handler . handle ( this . request , this . response , ( ) - > this . token ) ;
CsrfToken csrfToken = ( CsrfToken ) this . request . getAttribute ( CsrfToken . class . getName ( ) ) ;
@ -216,6 +222,7 @@ public class CsrfFilterTests {
@@ -216,6 +222,7 @@ public class CsrfFilterTests {
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
}
@ -224,8 +231,8 @@ public class CsrfFilterTests {
@@ -224,8 +231,8 @@ public class CsrfFilterTests {
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam ( )
throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler ( ) ;
handler . handle ( this . request , this . response , ( ) - > this . token ) ;
CsrfToken csrfToken = ( CsrfToken ) this . request . getAttribute ( CsrfToken . class . getName ( ) ) ;
@ -234,6 +241,7 @@ public class CsrfFilterTests {
@@ -234,6 +241,7 @@ public class CsrfFilterTests {
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
}
@ -241,8 +249,8 @@ public class CsrfFilterTests {
@@ -241,8 +249,8 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestExistingToken ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler ( ) ;
handler . handle ( this . request , this . response , ( ) - > this . token ) ;
CsrfToken csrfToken = ( CsrfToken ) this . request . getAttribute ( CsrfToken . class . getName ( ) ) ;
@ -250,6 +258,7 @@ public class CsrfFilterTests {
@@ -250,6 +258,7 @@ public class CsrfFilterTests {
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
verify ( this . tokenRepository , never ( ) ) . saveToken ( any ( CsrfToken . class ) , any ( HttpServletRequest . class ) ,
@ -259,8 +268,8 @@ public class CsrfFilterTests {
@@ -259,8 +268,8 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestGenerateToken ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , true ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler ( ) ;
handler . handle ( this . request , this . response , ( ) - > this . token ) ;
CsrfToken csrfToken = ( CsrfToken ) this . request . getAttribute ( CsrfToken . class . getName ( ) ) ;
@ -268,6 +277,7 @@ public class CsrfFilterTests {
@@ -268,6 +277,7 @@ public class CsrfFilterTests {
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
// LazyCsrfTokenRepository requires the response as an attribute
assertThat ( this . request . getAttribute ( HttpServletResponse . class . getName ( ) ) ) . isEqualTo ( this . response ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
@ -332,11 +342,12 @@ public class CsrfFilterTests {
@@ -332,11 +342,12 @@ public class CsrfFilterTests {
this . filter = new CsrfFilter ( this . tokenRepository ) ;
this . filter . setRequireCsrfProtectionMatcher ( this . requestMatcher ) ;
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
assertThat ( this . response . getStatus ( ) ) . isEqualTo ( HttpServletResponse . SC_FORBIDDEN ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -360,22 +371,24 @@ public class CsrfFilterTests {
@@ -360,22 +371,24 @@ public class CsrfFilterTests {
given ( token . getToken ( ) ) . willReturn ( null ) ;
given ( token . getHeaderName ( ) ) . willReturn ( this . token . getHeaderName ( ) ) ;
given ( token . getParameterName ( ) ) . willReturn ( this . token . getParameterName ( ) ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken ( token , false ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
filter . doFilterInternal ( this . request , this . response , this . filterChain ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
assertThat ( this . response . getStatus ( ) ) . isEqualTo ( HttpServletResponse . SC_OK ) ;
}
@Test
public void doFilterWhenRequestHandlerThenUsed ( ) throws Exception {
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
CsrfTokenRequestHandler requestHandler = mock ( CsrfTokenRequestHandler . class ) ;
this . filter = createCsrfFilter ( this . tokenRepository ) ;
this . filter . setRequestHandler ( requestHandler ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . tokenRepository ) . loadDeferredToken ( this . request , this . response ) ;
verify ( requestHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( ) ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
@ -384,11 +397,12 @@ public class CsrfFilterTests {
@@ -384,11 +397,12 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess ( ) throws Exception {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThat ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( "_csrf" ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
assertThat ( this . response . getStatus ( ) ) . isEqualTo ( HttpServletResponse . SC_OK ) ;
@ -407,10 +421,11 @@ public class CsrfFilterTests {
@@ -407,10 +421,11 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException ( ) throws Exception {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) )
. willReturn ( new TestDeferredCsrf Token ( this . token , false ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrf Token( this . token , false ) ;
given ( this . tokenRepository . loadDeferred Token( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( AccessDeniedException . class ) ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -435,10 +450,11 @@ public class CsrfFilterTests {
@@ -435,10 +450,11 @@ public class CsrfFilterTests {
requestHandler . setCsrfRequestAttributeName ( csrfAttrName ) ;
filter . setRequestHandler ( requestHandler ) ;
CsrfToken expectedCsrfToken = mock ( CsrfToken . class ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( expectedCsrfToken , true ) ) ;
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken ( expectedCsrfToken , true ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) ) . willReturn ( deferredCsrfToken ) ;
filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThat ( this . request . getAttribute ( DeferredCsrfToken . class . getName ( ) ) ) . isSameAs ( deferredCsrfToken ) ;
verifyNoInteractions ( expectedCsrfToken ) ;
CsrfToken tokenFromRequest = ( CsrfToken ) this . request . getAttribute ( csrfAttrName ) ;