|
|
|
@ -1,5 +1,5 @@ |
|
|
|
/* |
|
|
|
/* |
|
|
|
* Copyright 2002-2020 the original author or authors. |
|
|
|
* Copyright 2002-2017 the original author or authors. |
|
|
|
* |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@ -20,14 +20,10 @@ import org.junit.Test; |
|
|
|
import org.junit.runner.RunWith; |
|
|
|
import org.junit.runner.RunWith; |
|
|
|
import org.mockito.Mock; |
|
|
|
import org.mockito.Mock; |
|
|
|
import org.mockito.junit.MockitoJUnitRunner; |
|
|
|
import org.mockito.junit.MockitoJUnitRunner; |
|
|
|
|
|
|
|
|
|
|
|
import org.springframework.http.HttpMethod; |
|
|
|
|
|
|
|
import org.springframework.http.HttpStatus; |
|
|
|
import org.springframework.http.HttpStatus; |
|
|
|
import org.springframework.http.MediaType; |
|
|
|
import org.springframework.http.MediaType; |
|
|
|
import org.springframework.mock.http.server.reactive.MockServerHttpRequest; |
|
|
|
import org.springframework.mock.http.server.reactive.MockServerHttpRequest; |
|
|
|
import org.springframework.mock.web.server.MockServerWebExchange; |
|
|
|
import org.springframework.mock.web.server.MockServerWebExchange; |
|
|
|
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; |
|
|
|
|
|
|
|
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult; |
|
|
|
|
|
|
|
import org.springframework.web.server.WebFilterChain; |
|
|
|
import org.springframework.web.server.WebFilterChain; |
|
|
|
import org.springframework.web.server.WebSession; |
|
|
|
import org.springframework.web.server.WebSession; |
|
|
|
import reactor.core.publisher.Mono; |
|
|
|
import reactor.core.publisher.Mono; |
|
|
|
@ -37,11 +33,9 @@ import reactor.test.publisher.PublisherProbe; |
|
|
|
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; |
|
|
|
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; |
|
|
|
import static org.mockito.ArgumentMatchers.any; |
|
|
|
import static org.mockito.ArgumentMatchers.any; |
|
|
|
import static org.mockito.Mockito.when; |
|
|
|
import static org.mockito.Mockito.when; |
|
|
|
import static org.springframework.mock.web.server.MockServerWebExchange.from; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
/** |
|
|
|
* @author Rob Winch |
|
|
|
* @author Rob Winch |
|
|
|
* @author Parikshit Dutta |
|
|
|
|
|
|
|
* @since 5.0 |
|
|
|
* @since 5.0 |
|
|
|
*/ |
|
|
|
*/ |
|
|
|
@RunWith(MockitoJUnitRunner.class) |
|
|
|
@RunWith(MockitoJUnitRunner.class) |
|
|
|
@ -55,10 +49,10 @@ public class CsrfWebFilterTests { |
|
|
|
|
|
|
|
|
|
|
|
private CsrfWebFilter csrfFilter = new CsrfWebFilter(); |
|
|
|
private CsrfWebFilter csrfFilter = new CsrfWebFilter(); |
|
|
|
|
|
|
|
|
|
|
|
private MockServerWebExchange get = from( |
|
|
|
private MockServerWebExchange get = MockServerWebExchange.from( |
|
|
|
MockServerHttpRequest.get("/")); |
|
|
|
MockServerHttpRequest.get("/")); |
|
|
|
|
|
|
|
|
|
|
|
private MockServerWebExchange post = from( |
|
|
|
private MockServerWebExchange post = MockServerWebExchange.from( |
|
|
|
MockServerHttpRequest.post("/")); |
|
|
|
MockServerHttpRequest.post("/")); |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
@ -110,7 +104,7 @@ public class CsrfWebFilterTests { |
|
|
|
this.csrfFilter.setCsrfTokenRepository(this.repository); |
|
|
|
this.csrfFilter.setCsrfTokenRepository(this.repository); |
|
|
|
when(this.repository.loadToken(any())) |
|
|
|
when(this.repository.loadToken(any())) |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
this.post = from(MockServerHttpRequest.post("/") |
|
|
|
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") |
|
|
|
.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID")); |
|
|
|
.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID")); |
|
|
|
|
|
|
|
|
|
|
|
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain); |
|
|
|
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain); |
|
|
|
@ -131,7 +125,7 @@ public class CsrfWebFilterTests { |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
when(this.repository.generateToken(any())) |
|
|
|
when(this.repository.generateToken(any())) |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
this.post = from(MockServerHttpRequest.post("/") |
|
|
|
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") |
|
|
|
.contentType(MediaType.APPLICATION_FORM_URLENCODED) |
|
|
|
.contentType(MediaType.APPLICATION_FORM_URLENCODED) |
|
|
|
.body(this.token.getParameterName() + "="+this.token.getToken())); |
|
|
|
.body(this.token.getParameterName() + "="+this.token.getToken())); |
|
|
|
|
|
|
|
|
|
|
|
@ -148,7 +142,7 @@ public class CsrfWebFilterTests { |
|
|
|
this.csrfFilter.setCsrfTokenRepository(this.repository); |
|
|
|
this.csrfFilter.setCsrfTokenRepository(this.repository); |
|
|
|
when(this.repository.loadToken(any())) |
|
|
|
when(this.repository.loadToken(any())) |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
this.post = from(MockServerHttpRequest.post("/") |
|
|
|
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") |
|
|
|
.header(this.token.getHeaderName(), this.token.getToken()+"INVALID")); |
|
|
|
.header(this.token.getHeaderName(), this.token.getToken()+"INVALID")); |
|
|
|
|
|
|
|
|
|
|
|
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain); |
|
|
|
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain); |
|
|
|
@ -169,7 +163,7 @@ public class CsrfWebFilterTests { |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
when(this.repository.generateToken(any())) |
|
|
|
when(this.repository.generateToken(any())) |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
.thenReturn(Mono.just(this.token)); |
|
|
|
this.post = from(MockServerHttpRequest.post("/") |
|
|
|
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") |
|
|
|
.header(this.token.getHeaderName(), this.token.getToken())); |
|
|
|
.header(this.token.getHeaderName(), this.token.getToken())); |
|
|
|
|
|
|
|
|
|
|
|
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain); |
|
|
|
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain); |
|
|
|
@ -179,14 +173,4 @@ public class CsrfWebFilterTests { |
|
|
|
|
|
|
|
|
|
|
|
chainResult.assertWasSubscribed(); |
|
|
|
chainResult.assertWasSubscribed(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
|
|
|
// gh-8452
|
|
|
|
|
|
|
|
public void matchesRequireCsrfProtectionWhenNonStandardHTTPMethodIsUsed() { |
|
|
|
|
|
|
|
HttpMethod customHttpMethod = HttpMethod.resolve("non-standard-http-method"); |
|
|
|
|
|
|
|
MockServerWebExchange nonStandardHttpRequest = from(MockServerHttpRequest.method(customHttpMethod, "/")); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ServerWebExchangeMatcher serverWebExchangeMatcher = CsrfWebFilter.DEFAULT_CSRF_MATCHER; |
|
|
|
|
|
|
|
assertThat(serverWebExchangeMatcher.matches(nonStandardHttpRequest).map(MatchResult::isMatch).block()).isTrue(); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|