@ -130,8 +130,8 @@ public class CsrfFilterTests {
@@ -130,8 +130,8 @@ public class CsrfFilterTests {
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( InvalidCsrfTokenException . class ) ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -143,8 +143,8 @@ public class CsrfFilterTests {
@@ -143,8 +143,8 @@ public class CsrfFilterTests {
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) + " INVALID" ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( InvalidCsrfTokenException . class ) ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -156,8 +156,8 @@ public class CsrfFilterTests {
@@ -156,8 +156,8 @@ public class CsrfFilterTests {
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . addHeader ( this . token . getHeaderName ( ) , this . token . getToken ( ) + " INVALID" ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( InvalidCsrfTokenException . class ) ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -168,11 +168,14 @@ public class CsrfFilterTests {
@@ -168,11 +168,14 @@ public class CsrfFilterTests {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
this . request . addHeader ( this . token . getHeaderName ( ) , this . token . getToken ( ) + " INVALID" ) ;
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler ( ) ;
handler . handle ( this . request , this . response , ( ) - > this . token ) ;
CsrfToken csrfToken = ( CsrfToken ) this . request . getAttribute ( CsrfToken . class . getName ( ) ) ;
this . request . setParameter ( csrfToken . getParameterName ( ) , csrfToken . getToken ( ) ) ;
this . request . addHeader ( csrfToken . getHeaderName ( ) , csrfToken . getToken ( ) + " INVALID" ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( InvalidCsrfTokenException . class ) ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -183,8 +186,8 @@ public class CsrfFilterTests {
@@ -183,8 +186,8 @@ public class CsrfFilterTests {
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
}
@ -195,8 +198,8 @@ public class CsrfFilterTests {
@@ -195,8 +198,8 @@ public class CsrfFilterTests {
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , true ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
}
@ -206,10 +209,13 @@ public class CsrfFilterTests {
@@ -206,10 +209,13 @@ public class CsrfFilterTests {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . addHeader ( this . token . getHeaderName ( ) , this . token . getToken ( ) ) ;
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler ( ) ;
handler . handle ( this . request , this . response , ( ) - > this . token ) ;
CsrfToken csrfToken = ( CsrfToken ) this . request . getAttribute ( CsrfToken . class . getName ( ) ) ;
this . request . addHeader ( csrfToken . getHeaderName ( ) , csrfToken . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
}
@ -220,11 +226,14 @@ public class CsrfFilterTests {
@@ -220,11 +226,14 @@ public class CsrfFilterTests {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) + " INVALID" ) ;
this . request . addHeader ( this . token . getHeaderName ( ) , this . token . getToken ( ) ) ;
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler ( ) ;
handler . handle ( this . request , this . response , ( ) - > this . token ) ;
CsrfToken csrfToken = ( CsrfToken ) this . request . getAttribute ( CsrfToken . class . getName ( ) ) ;
this . request . setParameter ( csrfToken . getParameterName ( ) , csrfToken . getToken ( ) + " INVALID" ) ;
this . request . addHeader ( csrfToken . getHeaderName ( ) , csrfToken . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
}
@ -234,10 +243,13 @@ public class CsrfFilterTests {
@@ -234,10 +243,13 @@ public class CsrfFilterTests {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler ( ) ;
handler . handle ( this . request , this . response , ( ) - > this . token ) ;
CsrfToken csrfToken = ( CsrfToken ) this . request . getAttribute ( CsrfToken . class . getName ( ) ) ;
this . request . setParameter ( csrfToken . getParameterName ( ) , csrfToken . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
verify ( this . tokenRepository , never ( ) ) . saveToken ( any ( CsrfToken . class ) , any ( HttpServletRequest . class ) ,
@ -249,10 +261,13 @@ public class CsrfFilterTests {
@@ -249,10 +261,13 @@ public class CsrfFilterTests {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , true ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler ( ) ;
handler . handle ( this . request , this . response , ( ) - > this . token ) ;
CsrfToken csrfToken = ( CsrfToken ) this . request . getAttribute ( CsrfToken . class . getName ( ) ) ;
this . request . setParameter ( csrfToken . getParameterName ( ) , csrfToken . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
// LazyCsrfTokenRepository requires the response as an attribute
assertThat ( this . request . getAttribute ( HttpServletResponse . class . getName ( ) ) ) . isEqualTo ( this . response ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
@ -320,8 +335,8 @@ public class CsrfFilterTests {
@@ -320,8 +335,8 @@ public class CsrfFilterTests {
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isNotNull ( ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . response . getStatus ( ) ) . isEqualTo ( HttpServletResponse . SC_FORBIDDEN ) ;
verifyNoMoreInteractions ( this . filterChain ) ;
}
@ -371,12 +386,9 @@ public class CsrfFilterTests {
@@ -371,12 +386,9 @@ public class CsrfFilterTests {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( false ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler ( ) ;
requestHandler . setCsrfRequestAttributeName ( this . token . getParameterName ( ) ) ;
this . filter . setRequestHandler ( requestHandler ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThat ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( this . token . getParameterName ( ) ) ) . isNotNull ( ) ;
assertThat ( this . request . getAttribute ( "_csrf" ) ) . isNotNull ( ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
assertThat ( this . response . getStatus ( ) ) . isEqualTo ( HttpServletResponse . SC_OK ) ;
@ -397,8 +409,6 @@ public class CsrfFilterTests {
@@ -397,8 +409,6 @@ public class CsrfFilterTests {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler ( ) ;
this . filter . setRequestHandler ( requestHandler ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( AccessDeniedException . class ) ) ;
@ -421,7 +431,7 @@ public class CsrfFilterTests {
@@ -421,7 +431,7 @@ public class CsrfFilterTests {
throws ServletException , IOException {
CsrfFilter filter = createCsrfFilter ( this . tokenRepository ) ;
String csrfAttrName = "_csrf" ;
CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler ( ) ;
CsrfTokenRequestAttributeHandler requestHandler = new Xor CsrfTokenRequestAttributeHandler( ) ;
requestHandler . setCsrfRequestAttributeName ( csrfAttrName ) ;
filter . setRequestHandler ( requestHandler ) ;
CsrfToken expectedCsrfToken = mock ( CsrfToken . class ) ;
@ -432,7 +442,7 @@ public class CsrfFilterTests {
@@ -432,7 +442,7 @@ public class CsrfFilterTests {
verifyNoInteractions ( expectedCsrfToken ) ;
CsrfToken tokenFromRequest = ( CsrfToken ) this . request . getAttribute ( csrfAttrName ) ;
assertThatCsrfToken ( tokenFromRequest ) . isEqualTo ( expectedCsrfToken ) ;
assertThatCsrfToken ( tokenFromRequest ) . isNotNull ( ) ;
}
}