Browse Source

CORS support in HTTP method predicate

This commit introduces CORS support for the HttpMethodPredicate in
WebMvc.fn and WebFlux.fn.

Closes gh-24564
pull/24674/head
Arjen Poutsma 6 years ago
parent
commit
c03cdbac21
  1. 21
      spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java
  2. 144
      spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java
  3. 20
      spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java
  4. 19
      spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java

21
spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -52,6 +52,7 @@ import org.springframework.lang.NonNull; @@ -52,6 +52,7 @@ import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.cors.reactive.CorsUtils;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
@ -449,11 +450,25 @@ public abstract class RequestPredicates { @@ -449,11 +450,25 @@ public abstract class RequestPredicates {
@Override
public boolean test(ServerRequest request) {
boolean match = this.httpMethods.contains(request.method());
traceMatch("Method", this.httpMethods, request.method(), match);
HttpMethod method = method(request);
boolean match = this.httpMethods.contains(method);
traceMatch("Method", this.httpMethods, method, match);
return match;
}
@Nullable
private static HttpMethod method(ServerRequest request) {
if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) {
String accessControlRequestMethod =
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
return HttpMethod.resolve(accessControlRequestMethod);
}
else {
return request.method();
}
}
@Override
public void accept(Visitor visitor) {
visitor.method(Collections.unmodifiableSet(this.httpMethods));

144
spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,10 +22,14 @@ import java.util.function.Function; @@ -22,10 +22,14 @@ import java.util.function.Function;
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest;
import org.springframework.web.testfixture.server.MockServerWebExchange;
import org.springframework.web.util.pattern.PathPatternParser;
import static java.util.Collections.emptyList;
import static org.assertj.core.api.Assertions.assertThat;
/**
@ -33,98 +37,133 @@ import static org.assertj.core.api.Assertions.assertThat; @@ -33,98 +37,133 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
public class RequestPredicatesTests {
@Test
public void all() {
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com").build();
MockServerWebExchange mockExchange = MockServerWebExchange.from(mockRequest);
RequestPredicate predicate = RequestPredicates.all();
MockServerRequest request = MockServerRequest.builder().build();
ServerRequest request = new DefaultServerRequest(mockExchange, Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
}
@Test
public void method() {
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com").build();
HttpMethod httpMethod = HttpMethod.GET;
RequestPredicate predicate = RequestPredicates.method(httpMethod);
MockServerRequest request = MockServerRequest.builder().method(httpMethod).build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
mockRequest = MockServerHttpRequest.post("https://example.com").build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isFalse();
}
@Test
public void methodCorsPreFlight() {
RequestPredicate predicate = RequestPredicates.method(HttpMethod.PUT);
MockServerHttpRequest mockRequest = MockServerHttpRequest.options("https://example.com")
.header("Origin", "https://example.com")
.header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "PUT")
.build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
request = MockServerRequest.builder().method(HttpMethod.POST).build();
mockRequest = MockServerHttpRequest.options("https://example.com")
.header("Origin", "https://example.com")
.header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST")
.build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isFalse();
}
@Test
public void methods() {
RequestPredicate predicate = RequestPredicates.methods(HttpMethod.GET, HttpMethod.HEAD);
MockServerRequest request = MockServerRequest.builder().method(HttpMethod.GET).build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com").build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
request = MockServerRequest.builder().method(HttpMethod.HEAD).build();
mockRequest = MockServerHttpRequest.head("https://example.com").build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
request = MockServerRequest.builder().method(HttpMethod.POST).build();
mockRequest = MockServerHttpRequest.post("https://example.com").build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isFalse();
}
@Test
public void allMethods() {
URI uri = URI.create("http://localhost/path");
RequestPredicate predicate = RequestPredicates.GET("/p*");
MockServerRequest request = MockServerRequest.builder().method(HttpMethod.GET).uri(uri).build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/path").build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
predicate = RequestPredicates.HEAD("/p*");
request = MockServerRequest.builder().method(HttpMethod.HEAD).uri(uri).build();
mockRequest = MockServerHttpRequest.head("https://example.com/path").build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
predicate = RequestPredicates.POST("/p*");
request = MockServerRequest.builder().method(HttpMethod.POST).uri(uri).build();
mockRequest = MockServerHttpRequest.post("https://example.com/path").build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
predicate = RequestPredicates.PUT("/p*");
request = MockServerRequest.builder().method(HttpMethod.PUT).uri(uri).build();
mockRequest = MockServerHttpRequest.put("https://example.com/path").build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
predicate = RequestPredicates.PATCH("/p*");
request = MockServerRequest.builder().method(HttpMethod.PATCH).uri(uri).build();
mockRequest = MockServerHttpRequest.patch("https://example.com/path").build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
predicate = RequestPredicates.DELETE("/p*");
request = MockServerRequest.builder().method(HttpMethod.DELETE).uri(uri).build();
mockRequest = MockServerHttpRequest.delete("https://example.com/path").build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
predicate = RequestPredicates.OPTIONS("/p*");
request = MockServerRequest.builder().method(HttpMethod.OPTIONS).uri(uri).build();
mockRequest = MockServerHttpRequest.options("https://example.com/path").build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
}
@Test
public void path() {
URI uri = URI.create("http://localhost/path");
URI uri = URI.create("https://localhost/path");
RequestPredicate predicate = RequestPredicates.path("/p*");
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get(uri.toString()).build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList());
assertThat(predicate.test(request)).isTrue();
request = MockServerRequest.builder().build();
mockRequest = MockServerHttpRequest.head("https://example.com").build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isFalse();
}
@Test
public void pathNoLeadingSlash() {
URI uri = URI.create("http://localhost/path");
RequestPredicate predicate = RequestPredicates.path("p*");
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/path").build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
}
@Test
public void pathEncoded() {
URI uri = URI.create("http://localhost/foo%20bar");
URI uri = URI.create("https://localhost/foo%20bar");
RequestPredicate predicate = RequestPredicates.path("/foo bar");
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
request = MockServerRequest.builder().build();
assertThat(predicate.test(request)).isFalse();
}
@Test
@ -133,9 +172,9 @@ public class RequestPredicatesTests { @@ -133,9 +172,9 @@ public class RequestPredicatesTests {
parser.setCaseSensitive(false);
Function<String, RequestPredicate> pathPredicates = RequestPredicates.pathPredicates(parser);
URI uri = URI.create("http://localhost/path");
RequestPredicate predicate = pathPredicates.apply("/P*");
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/path").build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
}
@ -146,10 +185,16 @@ public class RequestPredicatesTests { @@ -146,10 +185,16 @@ public class RequestPredicatesTests {
RequestPredicate predicate =
RequestPredicates.headers(
headers -> headers.header(name).equals(Collections.singletonList(value)));
MockServerRequest request = MockServerRequest.builder().header(name, value).build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
.header(name, value)
.build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
request = MockServerRequest.builder().build();
mockRequest = MockServerHttpRequest.get("https://example.com")
.header(name, "bar")
.build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isFalse();
}
@ -157,10 +202,16 @@ public class RequestPredicatesTests { @@ -157,10 +202,16 @@ public class RequestPredicatesTests {
public void contentType() {
MediaType json = MediaType.APPLICATION_JSON;
RequestPredicate predicate = RequestPredicates.contentType(json);
MockServerRequest request = MockServerRequest.builder().header("Content-Type", json.toString()).build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.CONTENT_TYPE, json.toString())
.build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
request = MockServerRequest.builder().build();
mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.CONTENT_TYPE, "foo/bar")
.build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isFalse();
}
@ -168,10 +219,16 @@ public class RequestPredicatesTests { @@ -168,10 +219,16 @@ public class RequestPredicatesTests {
public void accept() {
MediaType json = MediaType.APPLICATION_JSON;
RequestPredicate predicate = RequestPredicates.accept(json);
MockServerRequest request = MockServerRequest.builder().header("Accept", json.toString()).build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.ACCEPT, json.toString())
.build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
request = MockServerRequest.builder().header("Accept", MediaType.TEXT_XML_VALUE).build();
mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.ACCEPT, "foo/bar")
.build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isFalse();
}
@ -179,25 +236,30 @@ public class RequestPredicatesTests { @@ -179,25 +236,30 @@ public class RequestPredicatesTests {
public void pathExtension() {
RequestPredicate predicate = RequestPredicates.pathExtension("txt");
URI uri = URI.create("http://localhost/file.txt");
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
URI uri = URI.create("https://localhost/file.txt");
MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
uri = URI.create("http://localhost/FILE.TXT");
request = MockServerRequest.builder().uri(uri).build();
uri = URI.create("https://localhost/FILE.TXT");
mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
predicate = RequestPredicates.pathExtension("bar");
assertThat(predicate.test(request)).isFalse();
uri = URI.create("http://localhost/file.foo");
request = MockServerRequest.builder().uri(uri).build();
uri = URI.create("https://localhost/file.foo");
mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isFalse();
}
@Test
public void queryParam() {
MockServerRequest request = MockServerRequest.builder().queryParam("foo", "bar").build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
.queryParam("foo", "bar").build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
RequestPredicate predicate = RequestPredicates.queryParam("foo", s -> s.equals("bar"));
assertThat(predicate.test(request)).isTrue();

20
spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -53,6 +53,7 @@ import org.springframework.lang.NonNull; @@ -53,6 +53,7 @@ import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.util.UriBuilder;
import org.springframework.web.util.UriUtils;
import org.springframework.web.util.pattern.PathPattern;
@ -444,11 +445,24 @@ public abstract class RequestPredicates { @@ -444,11 +445,24 @@ public abstract class RequestPredicates {
@Override
public boolean test(ServerRequest request) {
boolean match = this.httpMethods.contains(request.method());
traceMatch("Method", this.httpMethods, request.method(), match);
HttpMethod method = method(request);
boolean match = this.httpMethods.contains(method);
traceMatch("Method", this.httpMethods, method, match);
return match;
}
@Nullable
private static HttpMethod method(ServerRequest request) {
if (CorsUtils.isPreFlightRequest(request.servletRequest())) {
String accessControlRequestMethod =
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
return HttpMethod.resolve(accessControlRequestMethod);
}
else {
return request.method();
}
}
@Override
public void accept(Visitor visitor) {
visitor.method(Collections.unmodifiableSet(this.httpMethods));

19
spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,6 +21,7 @@ import java.util.function.Function; @@ -21,6 +21,7 @@ import java.util.function.Function;
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
@ -57,6 +58,22 @@ public class RequestPredicatesTests { @@ -57,6 +58,22 @@ public class RequestPredicatesTests {
assertThat(predicate.test(request)).isFalse();
}
@Test
public void methodCorsPreFlight() {
RequestPredicate predicate = RequestPredicates.method(HttpMethod.PUT);
MockHttpServletRequest servletRequest = new MockHttpServletRequest("OPTIONS", "https://example.com");
servletRequest.addHeader("Origin", "https://example.com");
servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "PUT");
ServerRequest request = new DefaultServerRequest(servletRequest, emptyList());
assertThat(predicate.test(request)).isTrue();
servletRequest.removeHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST");
request = new DefaultServerRequest(servletRequest, emptyList());
assertThat(predicate.test(request)).isFalse();
}
@Test
public void methods() {
RequestPredicate predicate = RequestPredicates.methods(HttpMethod.GET, HttpMethod.HEAD);

Loading…
Cancel
Save