diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 2a1f43cb2a..19711e54af 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -2731,6 +2731,19 @@ public class ServerHttpSecurity { return this; } + /** + * Specifies if {@link CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart + * data requests. + * + * @param enabled true if should read from multipart form body, else false. Default is false + * @return the {@link CsrfSpec} for additional configuration + */ + public CsrfSpec tokenFromMultipartDataEnabled(boolean enabled) { + this.filter.setTokenFromMultipartDataEnabled(enabled); + return this; + } + + /** * Allows method chaining to continue configuring the {@link ServerHttpSecurity} * @return the {@link ServerHttpSecurity} to continue configuring diff --git a/gradle/dependency-management.gradle b/gradle/dependency-management.gradle index a33638468b..4df20c7303 100644 --- a/gradle/dependency-management.gradle +++ b/gradle/dependency-management.gradle @@ -210,6 +210,7 @@ dependencyManagement { dependency 'org.slf4j:slf4j-nop:1.7.28' dependency 'org.sonatype.sisu.inject:cglib:2.2.1-v20090111' dependency 'org.springframework.ldap:spring-ldap-core:2.3.2.RELEASE' + dependency 'org.synchronoss.cloud:nio-multipart-parser:1.1.0' dependency 'org.thymeleaf:thymeleaf-spring5:3.0.11.RELEASE' dependency 'org.unbescape:unbescape:1.1.5.RELEASE' dependency 'org.w3c.css:sac:1.3' diff --git a/web/spring-security-web.gradle b/web/spring-security-web.gradle index 460a8f1461..1fb1344a23 100644 --- a/web/spring-security-web.gradle +++ b/web/spring-security-web.gradle @@ -25,6 +25,7 @@ dependencies { testCompile 'org.codehaus.groovy:groovy-all' testCompile 'org.skyscreamer:jsonassert' testCompile 'org.springframework:spring-webflux' + testCompile 'org.synchronoss.cloud:nio-multipart-parser' testCompile powerMock2Dependencies testCompile spockDependencies diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index 111306c3fd..33d06d39da 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -16,14 +16,12 @@ package org.springframework.security.web.server.csrf; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Set; - -import reactor.core.publisher.Mono; - +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.multipart.FormFieldPart; +import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; @@ -31,6 +29,11 @@ import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; +import reactor.core.publisher.Mono; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; import static java.lang.Boolean.TRUE; @@ -78,6 +81,8 @@ public class CsrfWebFilter implements WebFilter { private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN); + private boolean isTokenFromMultipartDataEnabled; + public void setAccessDeniedHandler( ServerAccessDeniedHandler accessDeniedHandler) { Assert.notNull(accessDeniedHandler, "accessDeniedHandler"); @@ -96,6 +101,15 @@ public class CsrfWebFilter implements WebFilter { this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher; } + /** + * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart + * data requests. + * @param tokenFromMultipartDataEnabled true if should read from multipart form body, else false. Default is false + */ + public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) { + this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled; + } + @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) { @@ -128,9 +142,26 @@ public class CsrfWebFilter implements WebFilter { return exchange.getFormData() .flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) + .switchIfEmpty(tokenFromMultipartData(exchange, expected)) .map(actual -> actual.equals(expected.getToken())); } + private Mono tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) { + if (!this.isTokenFromMultipartDataEnabled) { + return Mono.empty(); + } + ServerHttpRequest request = exchange.getRequest(); + HttpHeaders headers = request.getHeaders(); + MediaType contentType = headers.getContentType(); + if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) { + return Mono.empty(); + } + return exchange.getMultipartData() + .map(d -> d.getFirst(expected.getParameterName())) + .cast(FormFieldPart.class) + .map(FormFieldPart::value); + } + private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { return Mono.defer(() ->{ Mono csrfToken = csrfToken(exchange); diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java index 1101ddbba9..cd3df76693 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java @@ -20,17 +20,20 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; -import reactor.test.publisher.PublisherProbe; - import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebSession; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.PublisherProbe; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -38,6 +41,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import static org.springframework.mock.web.server.MockServerWebExchange.from; +import static org.springframework.web.reactive.function.BodyInserters.fromMultipartData; /** * @author Rob Winch @@ -57,7 +61,7 @@ public class CsrfWebFilterTests { private MockServerWebExchange get = from( MockServerHttpRequest.get("/")); - private MockServerWebExchange post = from( + private ServerWebExchange post = from( MockServerHttpRequest.post("/")); @Test @@ -193,4 +197,91 @@ public class CsrfWebFilterTests { verifyZeroInteractions(matcher); } + + @Test + public void filterWhenMultipartFormDataAndNotEnabledThenDenied() { + this.csrfFilter.setCsrfTokenRepository(this.repository); + when(this.repository.loadToken(any())) + .thenReturn(Mono.just(this.token)); + + WebTestClient client = WebTestClient.bindToController(new OkController()) + .webFilter(this.csrfFilter) + .build(); + + client.post() + .uri("/") + .contentType(MediaType.MULTIPART_FORM_DATA) + .body(fromMultipartData(this.token.getParameterName(), this.token.getToken())) + .exchange() + .expectStatus().isForbidden(); + } + + @Test + public void filterWhenMultipartFormDataAndEnabledThenGranted() { + this.csrfFilter.setCsrfTokenRepository(this.repository); + this.csrfFilter.setTokenFromMultipartDataEnabled(true); + when(this.repository.loadToken(any())) + .thenReturn(Mono.just(this.token)); + when(this.repository.generateToken(any())) + .thenReturn(Mono.just(this.token)); + + WebTestClient client = WebTestClient.bindToController(new OkController()) + .webFilter(this.csrfFilter) + .build(); + + client.post() + .uri("/") + .contentType(MediaType.MULTIPART_FORM_DATA) + .body(fromMultipartData(this.token.getParameterName(), this.token.getToken())) + .exchange() + .expectStatus().is2xxSuccessful(); + } + + @Test + public void filterWhenFormDataAndEnabledThenGranted() { + this.csrfFilter.setCsrfTokenRepository(this.repository); + this.csrfFilter.setTokenFromMultipartDataEnabled(true); + when(this.repository.loadToken(any())) + .thenReturn(Mono.just(this.token)); + when(this.repository.generateToken(any())) + .thenReturn(Mono.just(this.token)); + + WebTestClient client = WebTestClient.bindToController(new OkController()) + .webFilter(this.csrfFilter) + .build(); + + client.post() + .uri("/") + .contentType(MediaType.APPLICATION_FORM_URLENCODED) + .bodyValue(this.token.getParameterName() + "="+this.token.getToken()) + .exchange() + .expectStatus().is2xxSuccessful(); + } + + @Test + public void filterWhenMultipartMixedAndEnabledThenNotRead() { + this.csrfFilter.setCsrfTokenRepository(this.repository); + this.csrfFilter.setTokenFromMultipartDataEnabled(true); + when(this.repository.loadToken(any())) + .thenReturn(Mono.just(this.token)); + + WebTestClient client = WebTestClient.bindToController(new OkController()) + .webFilter(this.csrfFilter) + .build(); + + client.post() + .uri("/") + .contentType(MediaType.MULTIPART_MIXED) + .bodyValue(this.token.getParameterName() + "="+this.token.getToken()) + .exchange() + .expectStatus().isForbidden(); + } + + @RestController + static class OkController { + @RequestMapping("/**") + String ok() { + return "ok"; + } + } }