@ -43,7 +43,6 @@ import static org.mockito.BDDMockito.given;
@@ -43,7 +43,6 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.lenient ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.never ;
import static org.mockito.Mockito.spy ;
import static org.mockito.Mockito.times ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.verifyNoInteractions ;
@ -87,11 +86,7 @@ public class CsrfFilterTests {
@@ -87,11 +86,7 @@ public class CsrfFilterTests {
}
private CsrfFilter createCsrfFilter ( CsrfTokenRepository repository ) {
return createCsrfFilter ( new CsrfTokenRepositoryRequestHandler ( repository ) ) ;
}
private CsrfFilter createCsrfFilter ( CsrfTokenRequestHandler requestHandler ) {
CsrfFilter filter = new CsrfFilter ( requestHandler ) ;
CsrfFilter filter = new CsrfFilter ( repository ) ;
filter . setRequireCsrfProtectionMatcher ( this . requestMatcher ) ;
filter . setAccessDeniedHandler ( this . deniedHandler ) ;
return filter ;
@ -104,7 +99,7 @@ public class CsrfFilterTests {
@@ -104,7 +99,7 @@ public class CsrfFilterTests {
@Test
public void constructorNullRepository ( ) {
assertThatIllegalArgumentException ( ) . isThrownBy ( ( ) - > new CsrfFilter ( ( CsrfTokenRequestHandler ) null ) ) ;
assertThatIllegalArgumentException ( ) . isThrownBy ( ( ) - > new CsrfFilter ( null ) ) ;
}
// SEC-2276
@ -129,7 +124,8 @@ public class CsrfFilterTests {
@@ -129,7 +124,8 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedNoTokenPresent ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
@ -140,7 +136,8 @@ public class CsrfFilterTests {
@@ -140,7 +136,8 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedIncorrectTokenPresent ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) + " INVALID" ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
@ -152,7 +149,8 @@ public class CsrfFilterTests {
@@ -152,7 +149,8 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedIncorrectTokenPresentHeader ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . addHeader ( this . token . getHeaderName ( ) , this . token . getToken ( ) + " INVALID" ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
@ -165,7 +163,8 @@ public class CsrfFilterTests {
@@ -165,7 +163,8 @@ public class CsrfFilterTests {
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter ( )
throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
this . request . addHeader ( this . token . getHeaderName ( ) , this . token . getToken ( ) + " INVALID" ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
@ -178,7 +177,8 @@ public class CsrfFilterTests {
@@ -178,7 +177,8 @@ public class CsrfFilterTests {
@Test
public void doFilterNotCsrfRequestExistingToken ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( false ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
@ -189,7 +189,8 @@ public class CsrfFilterTests {
@@ -189,7 +189,8 @@ public class CsrfFilterTests {
@Test
public void doFilterNotCsrfRequestGenerateToken ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( false ) ;
given ( this . tokenRepository . generateToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , true ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
@ -200,7 +201,8 @@ public class CsrfFilterTests {
@@ -200,7 +201,8 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestExistingTokenHeader ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . addHeader ( this . token . getHeaderName ( ) , this . token . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
@ -213,7 +215,8 @@ public class CsrfFilterTests {
@@ -213,7 +215,8 @@ public class CsrfFilterTests {
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam ( )
throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) + " INVALID" ) ;
this . request . addHeader ( this . token . getHeaderName ( ) , this . token . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
@ -226,7 +229,8 @@ public class CsrfFilterTests {
@@ -226,7 +229,8 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestExistingToken ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
@ -240,7 +244,8 @@ public class CsrfFilterTests {
@@ -240,7 +244,8 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestGenerateToken ( ) throws ServletException , IOException {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . generateToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , true ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
@ -248,16 +253,17 @@ public class CsrfFilterTests {
@@ -248,16 +253,17 @@ public class CsrfFilterTests {
// LazyCsrfTokenRepository requires the response as an attribute
assertThat ( this . request . getAttribute ( HttpServletResponse . class . getName ( ) ) ) . isEqualTo ( this . response ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verify ( this . tokenRepository ) . saveToken ( this . token , this . request , this . response ) ;
verifyNoMoreInteractions ( this . deniedHandler ) ;
}
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods ( ) throws ServletException , IOException {
this . filter = create CsrfFilter( this . tokenRepository ) ;
this . filter = new CsrfFilter ( this . tokenRepository ) ;
this . filter . setAccessDeniedHandler ( this . deniedHandler ) ;
for ( String method : Arrays . asList ( "GET" , "TRACE" , "OPTIONS" , "HEAD" ) ) {
resetRequestResponse ( ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setMethod ( method ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
@ -273,11 +279,12 @@ public class CsrfFilterTests {
@@ -273,11 +279,12 @@ public class CsrfFilterTests {
* /
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive ( ) throws Exception {
this . filter = new CsrfFilter ( new CsrfTokenRepositoryRequestHandler ( this . tokenRepository ) ) ;
this . filter = new CsrfFilter ( this . tokenRepository ) ;
this . filter . setAccessDeniedHandler ( this . deniedHandler ) ;
for ( String method : Arrays . asList ( "get" , "TrAcE" , "oPTIOnS" , "hEaD" ) ) {
resetRequestResponse ( ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setMethod ( method ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) ,
@ -288,11 +295,12 @@ public class CsrfFilterTests {
@@ -288,11 +295,12 @@ public class CsrfFilterTests {
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods ( ) throws ServletException , IOException {
this . filter = new CsrfFilter ( new CsrfTokenRepositoryRequestHandler ( this . tokenRepository ) ) ;
this . filter = new CsrfFilter ( this . tokenRepository ) ;
this . filter . setAccessDeniedHandler ( this . deniedHandler ) ;
for ( String method : Arrays . asList ( "POST" , "PUT" , "PATCH" , "DELETE" , "INVALID" ) ) {
resetRequestResponse ( ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . request . setMethod ( method ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
verify ( this . deniedHandler ) . handle ( eq ( this . request ) , eq ( this . response ) ,
@ -303,10 +311,11 @@ public class CsrfFilterTests {
@@ -303,10 +311,11 @@ public class CsrfFilterTests {
@Test
public void doFilterDefaultAccessDenied ( ) throws ServletException , IOException {
this . filter = new CsrfFilter ( new CsrfTokenRepositoryRequestHandler ( this . tokenRepository ) ) ;
this . filter = new CsrfFilter ( this . tokenRepository ) ;
this . filter . setRequireCsrfProtectionMatcher ( this . requestMatcher ) ;
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
assertThatCsrfToken ( this . request . getAttribute ( this . csrfAttrName ) ) . isEqualTo ( this . token ) ;
assertThatCsrfToken ( this . request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( this . token ) ;
@ -317,7 +326,7 @@ public class CsrfFilterTests {
@@ -317,7 +326,7 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenSkipRequestInvokedThenSkips ( ) throws Exception {
CsrfTokenRepository repository = mock ( CsrfTokenRepository . class ) ;
CsrfFilter filter = create CsrfFilter( repository ) ;
CsrfFilter filter = new CsrfFilter ( repository ) ;
lenient ( ) . when ( repository . loadToken ( any ( HttpServletRequest . class ) ) ) . thenReturn ( this . token ) ;
MockHttpServletRequest request = new MockHttpServletRequest ( ) ;
CsrfFilter . skipRequest ( request ) ;
@ -333,7 +342,8 @@ public class CsrfFilterTests {
@@ -333,7 +342,8 @@ public class CsrfFilterTests {
given ( token . getToken ( ) ) . willReturn ( null ) ;
given ( token . getHeaderName ( ) ) . willReturn ( this . token . getHeaderName ( ) ) ;
given ( token . getParameterName ( ) ) . willReturn ( this . token . getParameterName ( ) ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( token ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( token , false ) ) ;
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
filter . doFilterInternal ( this . request , this . response , this . filterChain ) ;
assertThat ( this . response . getStatus ( ) ) . isEqualTo ( HttpServletResponse . SC_OK ) ;
@ -341,13 +351,15 @@ public class CsrfFilterTests {
@@ -341,13 +351,15 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenRequestHandlerThenUsed ( ) throws Exception {
CsrfTokenRequestHandler requestHandler = mock ( CsrfTokenRequestHandler . class ) ;
given ( requestHandler . handle ( this . request , this . response ) )
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . filter = createCsrfFilter ( requestHandler ) ;
CsrfTokenRequestHandler requestHandler = mock ( CsrfTokenRequestHandler . class ) ;
this . filter = createCsrfFilter ( this . tokenRepository ) ;
this . filter . setRequestHandler ( requestHandler ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
verify ( requestHandler ) . handle ( eq ( this . request ) , eq ( this . response ) ) ;
verify ( this . tokenRepository ) . loadDeferredToken ( this . request , this . response ) ;
verify ( requestHandler ) . handle ( eq ( this . request ) , eq ( this . response ) , any ( ) ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
}
@ -365,41 +377,20 @@ public class CsrfFilterTests {
@@ -365,41 +377,20 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet ( )
throws ServletException , IOException {
CsrfFilter filter = createCsrfFilter ( this . tokenRepository ) ;
String csrfAttrName = "_csrf" ;
CsrfTokenRepositoryRequest Handler requestHandler = new CsrfTokenRepositoryRequestHandler ( this . tokenRepository ) ;
CsrfTokenRequestAttribute Handler requestHandler = new CsrfTokenRequestAttributeHandler ( ) ;
requestHandler . setCsrfRequestAttributeName ( csrfAttrName ) ;
this . filter = createCsrfFilter ( requestHandler ) ;
CsrfToken expectedCsrfToken = spy ( this . token ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( expectedCsrfToken ) ;
filter . setRequestHandler ( requestHandler ) ;
CsrfToken expectedCsrfToken = mock ( CsrfToken . class ) ;
given ( this . tokenRepository . loadDeferredToken ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( expectedCsrfToken , true ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
filter . doFilter ( this . request , this . response , this . filterChain ) ;
verifyNoInteractions ( expectedCsrfToken ) ;
CsrfToken tokenFromRequest = ( CsrfToken ) this . request . getAttribute ( csrfAttrName ) ;
assertThatCsrfToken ( tokenFromRequest ) . isEqualTo ( expectedCsrfToken ) ;
}
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken ;
private final boolean isGenerated ;
private TestDeferredCsrfToken ( CsrfToken csrfToken , boolean isGenerated ) {
this . csrfToken = csrfToken ;
this . isGenerated = isGenerated ;
}
@Override
public CsrfToken get ( ) {
return this . csrfToken ;
}
@Override
public boolean isGenerated ( ) {
return this . isGenerated ;
}
}
}