|
|
|
@ -43,7 +43,6 @@ import static org.mockito.BDDMockito.given; |
|
|
|
import static org.mockito.Mockito.lenient; |
|
|
|
import static org.mockito.Mockito.lenient; |
|
|
|
import static org.mockito.Mockito.mock; |
|
|
|
import static org.mockito.Mockito.mock; |
|
|
|
import static org.mockito.Mockito.never; |
|
|
|
import static org.mockito.Mockito.never; |
|
|
|
import static org.mockito.Mockito.spy; |
|
|
|
|
|
|
|
import static org.mockito.Mockito.times; |
|
|
|
import static org.mockito.Mockito.times; |
|
|
|
import static org.mockito.Mockito.verify; |
|
|
|
import static org.mockito.Mockito.verify; |
|
|
|
import static org.mockito.Mockito.verifyNoInteractions; |
|
|
|
import static org.mockito.Mockito.verifyNoInteractions; |
|
|
|
@ -87,11 +86,7 @@ public class CsrfFilterTests { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) { |
|
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) { |
|
|
|
return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository)); |
|
|
|
CsrfFilter filter = new CsrfFilter(repository); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) { |
|
|
|
|
|
|
|
CsrfFilter filter = new CsrfFilter(requestHandler); |
|
|
|
|
|
|
|
filter.setRequireCsrfProtectionMatcher(this.requestMatcher); |
|
|
|
filter.setRequireCsrfProtectionMatcher(this.requestMatcher); |
|
|
|
filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
return filter; |
|
|
|
return filter; |
|
|
|
@ -104,7 +99,7 @@ public class CsrfFilterTests { |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void constructorNullRepository() { |
|
|
|
public void constructorNullRepository() { |
|
|
|
assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null)); |
|
|
|
assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// SEC-2276
|
|
|
|
// SEC-2276
|
|
|
|
@ -129,7 +124,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { |
|
|
|
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
@ -140,7 +136,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
@ -152,7 +149,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
@ -165,7 +163,8 @@ public class CsrfFilterTests { |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() |
|
|
|
throws ServletException, IOException { |
|
|
|
throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
@ -178,7 +177,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { |
|
|
|
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
@ -189,7 +189,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { |
|
|
|
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false); |
|
|
|
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, true)); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
@ -200,7 +201,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { |
|
|
|
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
@ -213,7 +215,8 @@ public class CsrfFilterTests { |
|
|
|
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() |
|
|
|
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() |
|
|
|
throws ServletException, IOException { |
|
|
|
throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
@ -226,7 +229,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { |
|
|
|
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
@ -240,7 +244,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { |
|
|
|
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, true)); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
@ -248,16 +253,17 @@ public class CsrfFilterTests { |
|
|
|
// LazyCsrfTokenRepository requires the response as an attribute
|
|
|
|
// LazyCsrfTokenRepository requires the response as an attribute
|
|
|
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); |
|
|
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
verify(this.tokenRepository).saveToken(this.token, this.request, this.response); |
|
|
|
|
|
|
|
verifyNoMoreInteractions(this.deniedHandler); |
|
|
|
verifyNoMoreInteractions(this.deniedHandler); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { |
|
|
|
this.filter = createCsrfFilter(this.tokenRepository); |
|
|
|
this.filter = new CsrfFilter(this.tokenRepository); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { |
|
|
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { |
|
|
|
resetRequestResponse(); |
|
|
|
resetRequestResponse(); |
|
|
|
|
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
@ -273,11 +279,12 @@ public class CsrfFilterTests { |
|
|
|
*/ |
|
|
|
*/ |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception { |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception { |
|
|
|
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); |
|
|
|
this.filter = new CsrfFilter(this.tokenRepository); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) { |
|
|
|
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) { |
|
|
|
resetRequestResponse(); |
|
|
|
resetRequestResponse(); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), |
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), |
|
|
|
@ -288,11 +295,12 @@ public class CsrfFilterTests { |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { |
|
|
|
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); |
|
|
|
this.filter = new CsrfFilter(this.tokenRepository); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) { |
|
|
|
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) { |
|
|
|
resetRequestResponse(); |
|
|
|
resetRequestResponse(); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), |
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), |
|
|
|
@ -303,10 +311,11 @@ public class CsrfFilterTests { |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterDefaultAccessDenied() throws ServletException, IOException { |
|
|
|
public void doFilterDefaultAccessDenied() throws ServletException, IOException { |
|
|
|
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); |
|
|
|
this.filter = new CsrfFilter(this.tokenRepository); |
|
|
|
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); |
|
|
|
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
@ -317,7 +326,7 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception { |
|
|
|
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception { |
|
|
|
CsrfTokenRepository repository = mock(CsrfTokenRepository.class); |
|
|
|
CsrfTokenRepository repository = mock(CsrfTokenRepository.class); |
|
|
|
CsrfFilter filter = createCsrfFilter(repository); |
|
|
|
CsrfFilter filter = new CsrfFilter(repository); |
|
|
|
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token); |
|
|
|
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token); |
|
|
|
MockHttpServletRequest request = new MockHttpServletRequest(); |
|
|
|
MockHttpServletRequest request = new MockHttpServletRequest(); |
|
|
|
CsrfFilter.skipRequest(request); |
|
|
|
CsrfFilter.skipRequest(request); |
|
|
|
@ -333,7 +342,8 @@ public class CsrfFilterTests { |
|
|
|
given(token.getToken()).willReturn(null); |
|
|
|
given(token.getToken()).willReturn(null); |
|
|
|
given(token.getHeaderName()).willReturn(this.token.getHeaderName()); |
|
|
|
given(token.getHeaderName()).willReturn(this.token.getHeaderName()); |
|
|
|
given(token.getParameterName()).willReturn(this.token.getParameterName()); |
|
|
|
given(token.getParameterName()).willReturn(this.token.getParameterName()); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(token, false)); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
filter.doFilterInternal(this.request, this.response, this.filterChain); |
|
|
|
filter.doFilterInternal(this.request, this.response, this.filterChain); |
|
|
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); |
|
|
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); |
|
|
|
@ -341,13 +351,15 @@ public class CsrfFilterTests { |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterWhenRequestHandlerThenUsed() throws Exception { |
|
|
|
public void doFilterWhenRequestHandlerThenUsed() throws Exception { |
|
|
|
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
given(requestHandler.handle(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.filter = createCsrfFilter(requestHandler); |
|
|
|
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); |
|
|
|
|
|
|
|
this.filter = createCsrfFilter(this.tokenRepository); |
|
|
|
|
|
|
|
this.filter.setRequestHandler(requestHandler); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
verify(requestHandler).handle(eq(this.request), eq(this.response)); |
|
|
|
verify(this.tokenRepository).loadDeferredToken(this.request, this.response); |
|
|
|
|
|
|
|
verify(requestHandler).handle(eq(this.request), eq(this.response), any()); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@ -365,41 +377,20 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet() |
|
|
|
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet() |
|
|
|
throws ServletException, IOException { |
|
|
|
throws ServletException, IOException { |
|
|
|
|
|
|
|
CsrfFilter filter = createCsrfFilter(this.tokenRepository); |
|
|
|
String csrfAttrName = "_csrf"; |
|
|
|
String csrfAttrName = "_csrf"; |
|
|
|
CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository); |
|
|
|
CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler(); |
|
|
|
requestHandler.setCsrfRequestAttributeName(csrfAttrName); |
|
|
|
requestHandler.setCsrfRequestAttributeName(csrfAttrName); |
|
|
|
this.filter = createCsrfFilter(requestHandler); |
|
|
|
filter.setRequestHandler(requestHandler); |
|
|
|
CsrfToken expectedCsrfToken = spy(this.token); |
|
|
|
CsrfToken expectedCsrfToken = mock(CsrfToken.class); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true)); |
|
|
|
|
|
|
|
|
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
|
|
|
|
|
|
|
|
verifyNoInteractions(expectedCsrfToken); |
|
|
|
verifyNoInteractions(expectedCsrfToken); |
|
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); |
|
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); |
|
|
|
assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken); |
|
|
|
assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private static final class TestDeferredCsrfToken implements DeferredCsrfToken { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private final CsrfToken csrfToken; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private final boolean isGenerated; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) { |
|
|
|
|
|
|
|
this.csrfToken = csrfToken; |
|
|
|
|
|
|
|
this.isGenerated = isGenerated; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public CsrfToken get() { |
|
|
|
|
|
|
|
return this.csrfToken; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public boolean isGenerated() { |
|
|
|
|
|
|
|
return this.isGenerated; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|