@ -18,6 +18,7 @@ package org.springframework.security.web.csrf;
@@ -18,6 +18,7 @@ package org.springframework.security.web.csrf;
import static org.fest.assertions.Assertions.assertThat ;
import static org.mockito.Matchers.any ;
import static org.mockito.Matchers.eq ;
import static org.mockito.Mockito.times ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.verifyZeroInteractions ;
import static org.mockito.Mockito.when ;
@ -27,8 +28,11 @@ import java.util.Arrays;
@@ -27,8 +28,11 @@ import java.util.Arrays;
import javax.servlet.FilterChain ;
import javax.servlet.ServletException ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import org.fest.assertions.GenericAssert ;
import org.fest.assertions.ObjectAssert ;
import org.junit.Before ;
import org.junit.Test ;
import org.junit.runner.RunWith ;
@ -59,12 +63,12 @@ public class CsrfFilterTests {
@@ -59,12 +63,12 @@ public class CsrfFilterTests {
private MockHttpServletResponse response ;
private CsrfToken token ;
private CsrfFilter filter ;
@Before
public void setup ( ) {
token = new CsrfToken ( "headerName" , "paramName" , "csrfTokenValue" ) ;
token = new DefaultCsrfToken ( "headerName" , "paramName" ,
"csrfTokenValue" ) ;
resetRequestResponse ( ) ;
filter = new CsrfFilter ( tokenRepository ) ;
filter . setRequireCsrfProtectionMatcher ( requestMatcher ) ;
@ -81,171 +85,221 @@ public class CsrfFilterTests {
@@ -81,171 +85,221 @@ public class CsrfFilterTests {
new CsrfFilter ( null ) ;
}
// SEC-2276
@Test
public void doFilterDoesNotSaveCsrfTokenUntilAccessed ( ) throws ServletException ,
IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( false ) ;
when ( tokenRepository . generateToken ( request ) ) . thenReturn ( token ) ;
filter . doFilter ( request , response , filterChain ) ;
CsrfToken attrToken = ( CsrfToken ) request . getAttribute ( token . getParameterName ( ) ) ;
// no CsrfToken should have been saved yet
verify ( tokenRepository , times ( 0 ) ) . saveToken ( any ( CsrfToken . class ) , any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
verify ( filterChain ) . doFilter ( request , response ) ;
// access the token
attrToken . getToken ( ) ;
// now the CsrfToken should have been saved
verify ( tokenRepository ) . saveToken ( eq ( token ) , any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
}
@Test
public void doFilterAccessDeniedNoTokenPresent ( ) throws ServletException , IOException {
public void doFilterAccessDeniedNoTokenPresent ( ) throws ServletException ,
IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( true ) ;
when ( tokenRepository . loadToken ( request ) ) . thenReturn ( token ) ;
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
verify ( deniedHandler ) . handle ( eq ( request ) , eq ( response ) , any ( InvalidCsrfTokenException . class ) ) ;
verify ( deniedHandler ) . handle ( eq ( request ) , eq ( response ) ,
any ( InvalidCsrfTokenException . class ) ) ;
verifyZeroInteractions ( filterChain ) ;
}
@Test
public void doFilterAccessDeniedIncorrectTokenPresent ( ) throws ServletException , IOException {
public void doFilterAccessDeniedIncorrectTokenPresent ( )
throws ServletException , IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( true ) ;
when ( tokenRepository . loadToken ( request ) ) . thenReturn ( token ) ;
request . setParameter ( token . getParameterName ( ) , token . getToken ( ) + " INVALID" ) ;
request . setParameter ( token . getParameterName ( ) , token . getToken ( )
+ " INVALID" ) ;
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
verify ( deniedHandler ) . handle ( eq ( request ) , eq ( response ) , any ( InvalidCsrfTokenException . class ) ) ;
verify ( deniedHandler ) . handle ( eq ( request ) , eq ( response ) ,
any ( InvalidCsrfTokenException . class ) ) ;
verifyZeroInteractions ( filterChain ) ;
}
@Test
public void doFilterAccessDeniedIncorrectTokenPresentHeader ( ) throws ServletException , IOException {
public void doFilterAccessDeniedIncorrectTokenPresentHeader ( )
throws ServletException , IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( true ) ;
when ( tokenRepository . loadToken ( request ) ) . thenReturn ( token ) ;
request . addHeader ( token . getHeaderName ( ) , token . getToken ( ) + " INVALID" ) ;
request . addHeader ( token . getHeaderName ( ) , token . getToken ( ) + " INVALID" ) ;
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
verify ( deniedHandler ) . handle ( eq ( request ) , eq ( response ) , any ( InvalidCsrfTokenException . class ) ) ;
verify ( deniedHandler ) . handle ( eq ( request ) , eq ( response ) ,
any ( InvalidCsrfTokenException . class ) ) ;
verifyZeroInteractions ( filterChain ) ;
}
@Test
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter ( ) throws ServletException , IOException {
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter ( )
throws ServletException , IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( true ) ;
when ( tokenRepository . loadToken ( request ) ) . thenReturn ( token ) ;
request . setParameter ( token . getParameterName ( ) , token . getToken ( ) ) ;
request . addHeader ( token . getHeaderName ( ) , token . getToken ( ) + " INVALID" ) ;
request . addHeader ( token . getHeaderName ( ) , token . getToken ( ) + " INVALID" ) ;
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
verify ( deniedHandler ) . handle ( eq ( request ) , eq ( response ) , any ( InvalidCsrfTokenException . class ) ) ;
verify ( deniedHandler ) . handle ( eq ( request ) , eq ( response ) ,
any ( InvalidCsrfTokenException . class ) ) ;
verifyZeroInteractions ( filterChain ) ;
}
@Test
public void doFilterNotCsrfRequestExistingToken ( ) throws ServletException , IOException {
public void doFilterNotCsrfRequestExistingToken ( ) throws ServletException ,
IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( false ) ;
when ( tokenRepository . loadToken ( request ) ) . thenReturn ( token ) ;
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
verify ( filterChain ) . doFilter ( request , response ) ;
verifyZeroInteractions ( deniedHandler ) ;
}
@Test
public void doFilterNotCsrfRequestGenerateToken ( ) throws ServletException , IOException {
public void doFilterNotCsrfRequestGenerateToken ( ) throws ServletException ,
IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( false ) ;
when ( tokenRepository . generateAndSaveToken ( request , response ) ) . thenReturn ( token ) ;
when ( tokenRepository . generateToken ( request ) )
. thenReturn ( token ) ;
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertToken ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertToken ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
verify ( filterChain ) . doFilter ( request , response ) ;
verifyZeroInteractions ( deniedHandler ) ;
}
@Test
public void doFilterIsCsrfRequestExistingTokenHeader ( ) throws ServletException , IOException {
public void doFilterIsCsrfRequestExistingTokenHeader ( )
throws ServletException , IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( true ) ;
when ( tokenRepository . loadToken ( request ) ) . thenReturn ( token ) ;
request . addHeader ( token . getHeaderName ( ) , token . getToken ( ) ) ;
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
verify ( filterChain ) . doFilter ( request , response ) ;
verifyZeroInteractions ( deniedHandler ) ;
}
@Test
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam ( ) throws ServletException , IOException {
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam ( )
throws ServletException , IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( true ) ;
when ( tokenRepository . loadToken ( request ) ) . thenReturn ( token ) ;
request . setParameter ( token . getParameterName ( ) , token . getToken ( ) + " INVALID" ) ;
request . setParameter ( token . getParameterName ( ) , token . getToken ( )
+ " INVALID" ) ;
request . addHeader ( token . getHeaderName ( ) , token . getToken ( ) ) ;
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
verify ( filterChain ) . doFilter ( request , response ) ;
verifyZeroInteractions ( deniedHandler ) ;
}
@Test
public void doFilterIsCsrfRequestExistingToken ( ) throws ServletException , IOException {
public void doFilterIsCsrfRequestExistingToken ( ) throws ServletException ,
IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( true ) ;
when ( tokenRepository . loadToken ( request ) ) . thenReturn ( token ) ;
request . setParameter ( token . getParameterName ( ) , token . getToken ( ) ) ;
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
verify ( filterChain ) . doFilter ( request , response ) ;
verifyZeroInteractions ( deniedHandler ) ;
}
@Test
public void doFilterIsCsrfRequestGenerateToken ( ) throws ServletException , IOException {
public void doFilterIsCsrfRequestGenerateToken ( ) throws ServletException ,
IOException {
when ( requestMatcher . matches ( request ) ) . thenReturn ( true ) ;
when ( tokenRepository . generateAndSaveToken ( request , response ) ) . thenReturn ( token ) ;
when ( tokenRepository . generateToken ( request ) )
. thenReturn ( token ) ;
request . setParameter ( token . getParameterName ( ) , token . getToken ( ) ) ;
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertToken ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertToken ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
verify ( filterChain ) . doFilter ( request , response ) ;
verifyZeroInteractions ( deniedHandler ) ;
}
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods ( ) throws ServletException , IOException {
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods ( )
throws ServletException , IOException {
filter = new CsrfFilter ( tokenRepository ) ;
filter . setAccessDeniedHandler ( deniedHandler ) ;
for ( String method : Arrays . asList ( "GET" , "TRACE" , "OPTIONS" , "HEAD" ) ) {
for ( String method : Arrays . asList ( "GET" , "TRACE" , "OPTIONS" , "HEAD" ) ) {
resetRequestResponse ( ) ;
when ( tokenRepository . loadToken ( request ) ) . thenReturn ( token ) ;
request . setMethod ( method ) ;
@ -258,24 +312,28 @@ public class CsrfFilterTests {
@@ -258,24 +312,28 @@ public class CsrfFilterTests {
}
@Test
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods ( ) throws ServletException , IOException {
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods ( )
throws ServletException , IOException {
filter = new CsrfFilter ( tokenRepository ) ;
filter . setAccessDeniedHandler ( deniedHandler ) ;
for ( String method : Arrays . asList ( "POST" , "PUT" , "PATCH" , "DELETE" , "INVALID" ) ) {
for ( String method : Arrays . asList ( "POST" , "PUT" , "PATCH" , "DELETE" ,
"INVALID" ) ) {
resetRequestResponse ( ) ;
when ( tokenRepository . loadToken ( request ) ) . thenReturn ( token ) ;
request . setMethod ( method ) ;
filter . doFilter ( request , response , filterChain ) ;
verify ( deniedHandler ) . handle ( eq ( request ) , eq ( response ) , any ( InvalidCsrfTokenException . class ) ) ;
verify ( deniedHandler ) . handle ( eq ( request ) , eq ( response ) ,
any ( InvalidCsrfTokenException . class ) ) ;
verifyZeroInteractions ( filterChain ) ;
}
}
@Test
public void doFilterDefaultAccessDenied ( ) throws ServletException , IOException {
public void doFilterDefaultAccessDenied ( ) throws ServletException ,
IOException {
filter = new CsrfFilter ( tokenRepository ) ;
filter . setRequireCsrfProtectionMatcher ( requestMatcher ) ;
when ( requestMatcher . matches ( request ) ) . thenReturn ( true ) ;
@ -283,11 +341,13 @@ public class CsrfFilterTests {
@@ -283,11 +341,13 @@ public class CsrfFilterTests {
filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . getHeader ( token . getHeaderName ( ) ) ) . isEqualTo ( token . getToken ( ) ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo ( token ) ;
assertThat ( request . getAttribute ( token . getParameterName ( ) ) ) . isEqualTo (
token ) ;
assertThat ( request . getAttribute ( CsrfToken . class . getName ( ) ) ) . isEqualTo (
token ) ;
assertThat ( response . getStatus ( ) ) . isEqualTo ( HttpServletResponse . SC_FORBIDDEN ) ;
assertThat ( response . getStatus ( ) ) . isEqualTo (
HttpServletResponse . SC_FORBIDDEN ) ;
verifyZeroInteractions ( filterChain ) ;
}
@ -300,4 +360,29 @@ public class CsrfFilterTests {
@@ -300,4 +360,29 @@ public class CsrfFilterTests {
public void setAccessDeniedHandlerNull ( ) {
filter . setAccessDeniedHandler ( null ) ;
}
private static final CsrfTokenAssert assertToken ( Object token ) {
return new CsrfTokenAssert ( ( CsrfToken ) token ) ;
}
private static class CsrfTokenAssert extends
GenericAssert < CsrfTokenAssert , CsrfToken > {
/ * *
* Creates a new < / code > { @link ObjectAssert } < / code > .
*
* @param actual
* the target to verify .
* /
protected CsrfTokenAssert ( CsrfToken actual ) {
super ( CsrfTokenAssert . class , actual ) ;
}
public CsrfTokenAssert isEqualTo ( CsrfToken expected ) {
assertThat ( actual . getHeaderName ( ) ) . isEqualTo ( expected . getHeaderName ( ) ) ;
assertThat ( actual . getParameterName ( ) ) . isEqualTo ( expected . getParameterName ( ) ) ;
assertThat ( actual . getToken ( ) ) . isEqualTo ( expected . getToken ( ) ) ;
return this ;
}
}
}