|
|
|
@ -44,7 +44,6 @@ import static org.mockito.BDDMockito.given; |
|
|
|
import static org.mockito.Mockito.lenient; |
|
|
|
import static org.mockito.Mockito.lenient; |
|
|
|
import static org.mockito.Mockito.mock; |
|
|
|
import static org.mockito.Mockito.mock; |
|
|
|
import static org.mockito.Mockito.never; |
|
|
|
import static org.mockito.Mockito.never; |
|
|
|
import static org.mockito.Mockito.spy; |
|
|
|
|
|
|
|
import static org.mockito.Mockito.times; |
|
|
|
import static org.mockito.Mockito.times; |
|
|
|
import static org.mockito.Mockito.verify; |
|
|
|
import static org.mockito.Mockito.verify; |
|
|
|
import static org.mockito.Mockito.verifyNoInteractions; |
|
|
|
import static org.mockito.Mockito.verifyNoInteractions; |
|
|
|
@ -86,11 +85,7 @@ public class CsrfFilterTests { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) { |
|
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) { |
|
|
|
return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository)); |
|
|
|
CsrfFilter filter = new CsrfFilter(repository); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) { |
|
|
|
|
|
|
|
CsrfFilter filter = new CsrfFilter(requestHandler); |
|
|
|
|
|
|
|
filter.setRequireCsrfProtectionMatcher(this.requestMatcher); |
|
|
|
filter.setRequireCsrfProtectionMatcher(this.requestMatcher); |
|
|
|
filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
return filter; |
|
|
|
return filter; |
|
|
|
@ -103,7 +98,7 @@ public class CsrfFilterTests { |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void constructorNullRepository() { |
|
|
|
public void constructorNullRepository() { |
|
|
|
assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null)); |
|
|
|
assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// SEC-2276
|
|
|
|
// SEC-2276
|
|
|
|
@ -128,7 +123,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { |
|
|
|
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
@ -139,7 +135,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
@ -151,7 +148,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
@ -164,7 +162,8 @@ public class CsrfFilterTests { |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() |
|
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() |
|
|
|
throws ServletException, IOException { |
|
|
|
throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
@ -177,7 +176,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { |
|
|
|
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
@ -188,7 +188,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { |
|
|
|
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false); |
|
|
|
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, true)); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
@ -199,7 +200,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { |
|
|
|
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
@ -212,7 +214,8 @@ public class CsrfFilterTests { |
|
|
|
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() |
|
|
|
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() |
|
|
|
throws ServletException, IOException { |
|
|
|
throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); |
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
@ -225,7 +228,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { |
|
|
|
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
@ -239,7 +243,8 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { |
|
|
|
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, true)); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
@ -247,17 +252,17 @@ public class CsrfFilterTests { |
|
|
|
// LazyCsrfTokenRepository requires the response as an attribute
|
|
|
|
// LazyCsrfTokenRepository requires the response as an attribute
|
|
|
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); |
|
|
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
verify(this.tokenRepository).saveToken(this.token, this.request, this.response); |
|
|
|
|
|
|
|
verifyNoMoreInteractions(this.deniedHandler); |
|
|
|
verifyNoMoreInteractions(this.deniedHandler); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException { |
|
|
|
this.filter = createCsrfFilter(this.tokenRepository); |
|
|
|
this.filter = new CsrfFilter(this.tokenRepository); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { |
|
|
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) { |
|
|
|
resetRequestResponse(); |
|
|
|
resetRequestResponse(); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
@ -273,11 +278,12 @@ public class CsrfFilterTests { |
|
|
|
*/ |
|
|
|
*/ |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception { |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception { |
|
|
|
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); |
|
|
|
this.filter = new CsrfFilter(this.tokenRepository); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) { |
|
|
|
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) { |
|
|
|
resetRequestResponse(); |
|
|
|
resetRequestResponse(); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), |
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), |
|
|
|
@ -288,11 +294,12 @@ public class CsrfFilterTests { |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { |
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException { |
|
|
|
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); |
|
|
|
this.filter = new CsrfFilter(this.tokenRepository); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler); |
|
|
|
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) { |
|
|
|
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) { |
|
|
|
resetRequestResponse(); |
|
|
|
resetRequestResponse(); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.request.setMethod(method); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), |
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), |
|
|
|
@ -303,10 +310,11 @@ public class CsrfFilterTests { |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterDefaultAccessDenied() throws ServletException, IOException { |
|
|
|
public void doFilterDefaultAccessDenied() throws ServletException, IOException { |
|
|
|
this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository)); |
|
|
|
this.filter = new CsrfFilter(this.tokenRepository); |
|
|
|
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); |
|
|
|
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); |
|
|
|
@ -317,7 +325,7 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception { |
|
|
|
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception { |
|
|
|
CsrfTokenRepository repository = mock(CsrfTokenRepository.class); |
|
|
|
CsrfTokenRepository repository = mock(CsrfTokenRepository.class); |
|
|
|
CsrfFilter filter = createCsrfFilter(repository); |
|
|
|
CsrfFilter filter = new CsrfFilter(repository); |
|
|
|
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token); |
|
|
|
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token); |
|
|
|
MockHttpServletRequest request = new MockHttpServletRequest(); |
|
|
|
MockHttpServletRequest request = new MockHttpServletRequest(); |
|
|
|
CsrfFilter.skipRequest(request); |
|
|
|
CsrfFilter.skipRequest(request); |
|
|
|
@ -333,7 +341,8 @@ public class CsrfFilterTests { |
|
|
|
given(token.getToken()).willReturn(null); |
|
|
|
given(token.getToken()).willReturn(null); |
|
|
|
given(token.getHeaderName()).willReturn(this.token.getHeaderName()); |
|
|
|
given(token.getHeaderName()).willReturn(this.token.getHeaderName()); |
|
|
|
given(token.getParameterName()).willReturn(this.token.getParameterName()); |
|
|
|
given(token.getParameterName()).willReturn(this.token.getParameterName()); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(token); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(token, false)); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true); |
|
|
|
filter.doFilterInternal(this.request, this.response, this.filterChain); |
|
|
|
filter.doFilterInternal(this.request, this.response, this.filterChain); |
|
|
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); |
|
|
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); |
|
|
|
@ -341,13 +350,15 @@ public class CsrfFilterTests { |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterWhenRequestHandlerThenUsed() throws Exception { |
|
|
|
public void doFilterWhenRequestHandlerThenUsed() throws Exception { |
|
|
|
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
given(requestHandler.handle(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false)); |
|
|
|
this.filter = createCsrfFilter(requestHandler); |
|
|
|
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); |
|
|
|
|
|
|
|
this.filter = createCsrfFilter(this.tokenRepository); |
|
|
|
|
|
|
|
this.filter.setRequestHandler(requestHandler); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken()); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
verify(requestHandler).handle(eq(this.request), eq(this.response)); |
|
|
|
verify(this.tokenRepository).loadDeferredToken(this.request, this.response); |
|
|
|
|
|
|
|
verify(requestHandler).handle(eq(this.request), eq(this.response), any()); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
verify(this.filterChain).doFilter(this.request, this.response); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@ -365,41 +376,20 @@ public class CsrfFilterTests { |
|
|
|
@Test |
|
|
|
@Test |
|
|
|
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet() |
|
|
|
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet() |
|
|
|
throws ServletException, IOException { |
|
|
|
throws ServletException, IOException { |
|
|
|
|
|
|
|
CsrfFilter filter = createCsrfFilter(this.tokenRepository); |
|
|
|
String csrfAttrName = "_csrf"; |
|
|
|
String csrfAttrName = "_csrf"; |
|
|
|
CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository); |
|
|
|
CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler(); |
|
|
|
requestHandler.setCsrfRequestAttributeName(csrfAttrName); |
|
|
|
requestHandler.setCsrfRequestAttributeName(csrfAttrName); |
|
|
|
this.filter = createCsrfFilter(requestHandler); |
|
|
|
filter.setRequestHandler(requestHandler); |
|
|
|
CsrfToken expectedCsrfToken = spy(this.token); |
|
|
|
CsrfToken expectedCsrfToken = mock(CsrfToken.class); |
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken); |
|
|
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)) |
|
|
|
|
|
|
|
.willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true)); |
|
|
|
|
|
|
|
|
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
filter.doFilter(this.request, this.response, this.filterChain); |
|
|
|
|
|
|
|
|
|
|
|
verifyNoInteractions(expectedCsrfToken); |
|
|
|
verifyNoInteractions(expectedCsrfToken); |
|
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); |
|
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); |
|
|
|
assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken); |
|
|
|
assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private static final class TestDeferredCsrfToken implements DeferredCsrfToken { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private final CsrfToken csrfToken; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private final boolean isGenerated; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) { |
|
|
|
|
|
|
|
this.csrfToken = csrfToken; |
|
|
|
|
|
|
|
this.isGenerated = isGenerated; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public CsrfToken get() { |
|
|
|
|
|
|
|
return this.csrfToken; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public boolean isGenerated() { |
|
|
|
|
|
|
|
return this.isGenerated; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|