@ -86,7 +86,11 @@ public class CsrfFilterTests {
}
}
private CsrfFilter createCsrfFilter ( CsrfTokenRepository repository ) {
private CsrfFilter createCsrfFilter ( CsrfTokenRepository repository ) {
CsrfFilter filter = new CsrfFilter ( repository ) ;
return createCsrfFilter ( new CsrfTokenRepositoryRequestHandler ( repository ) ) ;
}
private CsrfFilter createCsrfFilter ( CsrfTokenRequestHandler requestHandler ) {
CsrfFilter filter = new CsrfFilter ( requestHandler ) ;
filter . setRequireCsrfProtectionMatcher ( this . requestMatcher ) ;
filter . setRequireCsrfProtectionMatcher ( this . requestMatcher ) ;
filter . setAccessDeniedHandler ( this . deniedHandler ) ;
filter . setAccessDeniedHandler ( this . deniedHandler ) ;
return filter ;
return filter ;
@ -99,7 +103,7 @@ public class CsrfFilterTests {
@Test
@Test
public void constructorNullRepository ( ) {
public void constructorNullRepository ( ) {
assertThatIllegalArgumentException ( ) . isThrownBy ( ( ) - > new CsrfFilter ( null ) ) ;
assertThatIllegalArgumentException ( ) . isThrownBy ( ( ) - > new CsrfFilter ( ( CsrfTokenRequestHandler ) null ) ) ;
}
}
// SEC-2276
// SEC-2276
@ -249,7 +253,7 @@ public class CsrfFilterTests {
@Test
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods ( ) throws ServletException , IOException {
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods ( ) throws ServletException , IOException {
this . filter = new CsrfFilter ( this . tokenRepository ) ;
this . filter = create CsrfFilter( this . tokenRepository ) ;
this . filter . setAccessDeniedHandler ( this . deniedHandler ) ;
this . filter . setAccessDeniedHandler ( this . deniedHandler ) ;
for ( String method : Arrays . asList ( "GET" , "TRACE" , "OPTIONS" , "HEAD" ) ) {
for ( String method : Arrays . asList ( "GET" , "TRACE" , "OPTIONS" , "HEAD" ) ) {
resetRequestResponse ( ) ;
resetRequestResponse ( ) ;
@ -269,7 +273,7 @@ public class CsrfFilterTests {
* /
* /
@Test
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive ( ) throws Exception {
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive ( ) throws Exception {
this . filter = new CsrfFilter ( this . tokenRepository ) ;
this . filter = new CsrfFilter ( new CsrfTokenRepositoryRequestHandler ( this . tokenRepository ) ) ;
this . filter . setAccessDeniedHandler ( this . deniedHandler ) ;
this . filter . setAccessDeniedHandler ( this . deniedHandler ) ;
for ( String method : Arrays . asList ( "get" , "TrAcE" , "oPTIOnS" , "hEaD" ) ) {
for ( String method : Arrays . asList ( "get" , "TrAcE" , "oPTIOnS" , "hEaD" ) ) {
resetRequestResponse ( ) ;
resetRequestResponse ( ) ;
@ -284,7 +288,7 @@ public class CsrfFilterTests {
@Test
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods ( ) throws ServletException , IOException {
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods ( ) throws ServletException , IOException {
this . filter = new CsrfFilter ( this . tokenRepository ) ;
this . filter = new CsrfFilter ( new CsrfTokenRepositoryRequestHandler ( this . tokenRepository ) ) ;
this . filter . setAccessDeniedHandler ( this . deniedHandler ) ;
this . filter . setAccessDeniedHandler ( this . deniedHandler ) ;
for ( String method : Arrays . asList ( "POST" , "PUT" , "PATCH" , "DELETE" , "INVALID" ) ) {
for ( String method : Arrays . asList ( "POST" , "PUT" , "PATCH" , "DELETE" , "INVALID" ) ) {
resetRequestResponse ( ) ;
resetRequestResponse ( ) ;
@ -299,7 +303,7 @@ public class CsrfFilterTests {
@Test
@Test
public void doFilterDefaultAccessDenied ( ) throws ServletException , IOException {
public void doFilterDefaultAccessDenied ( ) throws ServletException , IOException {
this . filter = new CsrfFilter ( this . tokenRepository ) ;
this . filter = new CsrfFilter ( new CsrfTokenRepositoryRequestHandler ( this . tokenRepository ) ) ;
this . filter . setRequireCsrfProtectionMatcher ( this . requestMatcher ) ;
this . filter . setRequireCsrfProtectionMatcher ( this . requestMatcher ) ;
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
@ -313,7 +317,7 @@ public class CsrfFilterTests {
@Test
@Test
public void doFilterWhenSkipRequestInvokedThenSkips ( ) throws Exception {
public void doFilterWhenSkipRequestInvokedThenSkips ( ) throws Exception {
CsrfTokenRepository repository = mock ( CsrfTokenRepository . class ) ;
CsrfTokenRepository repository = mock ( CsrfTokenRepository . class ) ;
CsrfFilter filter = new CsrfFilter ( repository ) ;
CsrfFilter filter = create CsrfFilter( repository ) ;
lenient ( ) . when ( repository . loadToken ( any ( HttpServletRequest . class ) ) ) . thenReturn ( this . token ) ;
lenient ( ) . when ( repository . loadToken ( any ( HttpServletRequest . class ) ) ) . thenReturn ( this . token ) ;
MockHttpServletRequest request = new MockHttpServletRequest ( ) ;
MockHttpServletRequest request = new MockHttpServletRequest ( ) ;
CsrfFilter . skipRequest ( request ) ;
CsrfFilter . skipRequest ( request ) ;
@ -340,25 +344,13 @@ public class CsrfFilterTests {
CsrfTokenRequestHandler requestHandler = mock ( CsrfTokenRequestHandler . class ) ;
CsrfTokenRequestHandler requestHandler = mock ( CsrfTokenRequestHandler . class ) ;
given ( requestHandler . handle ( this . request , this . response ) )
given ( requestHandler . handle ( this . request , this . response ) )
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
. willReturn ( new TestDeferredCsrfToken ( this . token , false ) ) ;
this . filter . setRequestHandl er( requestHandler ) ;
this . filter = createCsrfFilt er ( requestHandler ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
this . request . setParameter ( this . token . getParameterName ( ) , this . token . getToken ( ) ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
verify ( requestHandler ) . handle ( eq ( this . request ) , eq ( this . response ) ) ;
verify ( requestHandler ) . handle ( eq ( this . request ) , eq ( this . response ) ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
}
}
@Test
public void doFilterWhenRequestResolverThenUsed ( ) throws Exception {
given ( this . requestMatcher . matches ( this . request ) ) . willReturn ( true ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( this . token ) ;
CsrfTokenRequestResolver requestResolver = mock ( CsrfTokenRequestResolver . class ) ;
given ( requestResolver . resolveCsrfTokenValue ( this . request , this . token ) ) . willReturn ( this . token . getToken ( ) ) ;
this . filter . setRequestResolver ( requestResolver ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
verify ( requestResolver ) . resolveCsrfTokenValue ( this . request , this . token ) ;
verify ( this . filterChain ) . doFilter ( this . request , this . response ) ;
}
@Test
@Test
public void setRequireCsrfProtectionMatcherNull ( ) {
public void setRequireCsrfProtectionMatcherNull ( ) {
assertThatIllegalArgumentException ( ) . isThrownBy ( ( ) - > this . filter . setRequireCsrfProtectionMatcher ( null ) ) ;
assertThatIllegalArgumentException ( ) . isThrownBy ( ( ) - > this . filter . setRequireCsrfProtectionMatcher ( null ) ) ;
@ -373,16 +365,14 @@ public class CsrfFilterTests {
@Test
@Test
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet ( )
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet ( )
throws ServletException , IOException {
throws ServletException , IOException {
CsrfFilter filter = createCsrfFilter ( this . tokenRepository ) ;
String csrfAttrName = "_csrf" ;
String csrfAttrName = "_csrf" ;
CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor ( ) ;
CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler ( this . tokenRepository ) ;
csrfTokenRequestProcessor . setTokenRepository ( this . tokenRepository ) ;
requestHandler . setCsrfRequestAttributeName ( csrfAttrName ) ;
csrfTokenRequestProcessor . setCsrfRequestAttributeName ( csrfAttrName ) ;
this . filter = createCsrfFilter ( requestHandler ) ;
filter . setRequestHandler ( csrfTokenRequestProcessor ) ;
CsrfToken expectedCsrfToken = spy ( this . token ) ;
CsrfToken expectedCsrfToken = spy ( this . token ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( expectedCsrfToken ) ;
given ( this . tokenRepository . loadToken ( this . request ) ) . willReturn ( expectedCsrfToken ) ;
filter . doFilter ( this . request , this . response , this . filterChain ) ;
this . filter . doFilter ( this . request , this . response , this . filterChain ) ;
verifyNoInteractions ( expectedCsrfToken ) ;
verifyNoInteractions ( expectedCsrfToken ) ;
CsrfToken tokenFromRequest = ( CsrfToken ) this . request . getAttribute ( csrfAttrName ) ;
CsrfToken tokenFromRequest = ( CsrfToken ) this . request . getAttribute ( csrfAttrName ) ;
@ -410,6 +400,6 @@ public class CsrfFilterTests {
return this . isGenerated ;
return this . isGenerated ;
}
}
} ;
}
}
}