diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java index cbb3a1795df..2569ce2d1a0 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,6 +52,7 @@ import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; +import org.springframework.web.cors.reactive.CorsUtils; import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; @@ -449,11 +450,25 @@ public abstract class RequestPredicates { @Override public boolean test(ServerRequest request) { - boolean match = this.httpMethods.contains(request.method()); - traceMatch("Method", this.httpMethods, request.method(), match); + HttpMethod method = method(request); + boolean match = this.httpMethods.contains(method); + traceMatch("Method", this.httpMethods, method, match); return match; } + @Nullable + private static HttpMethod method(ServerRequest request) { + if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) { + String accessControlRequestMethod = + request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); + return HttpMethod.resolve(accessControlRequestMethod); + } + else { + return request.method(); + } + } + + @Override public void accept(Visitor visitor) { visitor.method(Collections.unmodifiableSet(this.httpMethods)); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java index 7783d1fe683..4282baec278 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,10 +22,14 @@ import java.util.function.Function; import org.junit.jupiter.api.Test; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest; +import org.springframework.web.testfixture.server.MockServerWebExchange; import org.springframework.web.util.pattern.PathPatternParser; +import static java.util.Collections.emptyList; import static org.assertj.core.api.Assertions.assertThat; /** @@ -33,98 +37,133 @@ import static org.assertj.core.api.Assertions.assertThat; */ public class RequestPredicatesTests { + @Test public void all() { + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com").build(); + MockServerWebExchange mockExchange = MockServerWebExchange.from(mockRequest); RequestPredicate predicate = RequestPredicates.all(); - MockServerRequest request = MockServerRequest.builder().build(); + ServerRequest request = new DefaultServerRequest(mockExchange, Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); } @Test public void method() { + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com").build(); + HttpMethod httpMethod = HttpMethod.GET; RequestPredicate predicate = RequestPredicates.method(httpMethod); - MockServerRequest request = MockServerRequest.builder().method(httpMethod).build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isTrue(); + + mockRequest = MockServerHttpRequest.post("https://example.com").build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isFalse(); + } + + @Test + public void methodCorsPreFlight() { + RequestPredicate predicate = RequestPredicates.method(HttpMethod.PUT); + + MockServerHttpRequest mockRequest = MockServerHttpRequest.options("https://example.com") + .header("Origin", "https://example.com") + .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "PUT") + .build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isTrue(); - request = MockServerRequest.builder().method(HttpMethod.POST).build(); + mockRequest = MockServerHttpRequest.options("https://example.com") + .header("Origin", "https://example.com") + .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST") + .build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isFalse(); } + @Test public void methods() { RequestPredicate predicate = RequestPredicates.methods(HttpMethod.GET, HttpMethod.HEAD); - MockServerRequest request = MockServerRequest.builder().method(HttpMethod.GET).build(); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com").build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); - request = MockServerRequest.builder().method(HttpMethod.HEAD).build(); + mockRequest = MockServerHttpRequest.head("https://example.com").build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); - request = MockServerRequest.builder().method(HttpMethod.POST).build(); + mockRequest = MockServerHttpRequest.post("https://example.com").build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isFalse(); } @Test public void allMethods() { - URI uri = URI.create("http://localhost/path"); - RequestPredicate predicate = RequestPredicates.GET("/p*"); - MockServerRequest request = MockServerRequest.builder().method(HttpMethod.GET).uri(uri).build(); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/path").build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); predicate = RequestPredicates.HEAD("/p*"); - request = MockServerRequest.builder().method(HttpMethod.HEAD).uri(uri).build(); + mockRequest = MockServerHttpRequest.head("https://example.com/path").build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); predicate = RequestPredicates.POST("/p*"); - request = MockServerRequest.builder().method(HttpMethod.POST).uri(uri).build(); + mockRequest = MockServerHttpRequest.post("https://example.com/path").build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); predicate = RequestPredicates.PUT("/p*"); - request = MockServerRequest.builder().method(HttpMethod.PUT).uri(uri).build(); + mockRequest = MockServerHttpRequest.put("https://example.com/path").build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); predicate = RequestPredicates.PATCH("/p*"); - request = MockServerRequest.builder().method(HttpMethod.PATCH).uri(uri).build(); + mockRequest = MockServerHttpRequest.patch("https://example.com/path").build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); predicate = RequestPredicates.DELETE("/p*"); - request = MockServerRequest.builder().method(HttpMethod.DELETE).uri(uri).build(); + mockRequest = MockServerHttpRequest.delete("https://example.com/path").build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); predicate = RequestPredicates.OPTIONS("/p*"); - request = MockServerRequest.builder().method(HttpMethod.OPTIONS).uri(uri).build(); + mockRequest = MockServerHttpRequest.options("https://example.com/path").build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); } @Test public void path() { - URI uri = URI.create("http://localhost/path"); + URI uri = URI.create("https://localhost/path"); RequestPredicate predicate = RequestPredicates.path("/p*"); - MockServerRequest request = MockServerRequest.builder().uri(uri).build(); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get(uri.toString()).build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList()); assertThat(predicate.test(request)).isTrue(); - request = MockServerRequest.builder().build(); + mockRequest = MockServerHttpRequest.head("https://example.com").build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isFalse(); } @Test public void pathNoLeadingSlash() { - URI uri = URI.create("http://localhost/path"); RequestPredicate predicate = RequestPredicates.path("p*"); - MockServerRequest request = MockServerRequest.builder().uri(uri).build(); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/path").build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); } @Test public void pathEncoded() { - URI uri = URI.create("http://localhost/foo%20bar"); + URI uri = URI.create("https://localhost/foo%20bar"); RequestPredicate predicate = RequestPredicates.path("/foo bar"); - MockServerRequest request = MockServerRequest.builder().uri(uri).build(); + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); - - request = MockServerRequest.builder().build(); - assertThat(predicate.test(request)).isFalse(); } @Test @@ -133,9 +172,9 @@ public class RequestPredicatesTests { parser.setCaseSensitive(false); Function pathPredicates = RequestPredicates.pathPredicates(parser); - URI uri = URI.create("http://localhost/path"); RequestPredicate predicate = pathPredicates.apply("/P*"); - MockServerRequest request = MockServerRequest.builder().uri(uri).build(); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/path").build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); } @@ -146,10 +185,16 @@ public class RequestPredicatesTests { RequestPredicate predicate = RequestPredicates.headers( headers -> headers.header(name).equals(Collections.singletonList(value))); - MockServerRequest request = MockServerRequest.builder().header(name, value).build(); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com") + .header(name, value) + .build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); - request = MockServerRequest.builder().build(); + mockRequest = MockServerHttpRequest.get("https://example.com") + .header(name, "bar") + .build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isFalse(); } @@ -157,10 +202,16 @@ public class RequestPredicatesTests { public void contentType() { MediaType json = MediaType.APPLICATION_JSON; RequestPredicate predicate = RequestPredicates.contentType(json); - MockServerRequest request = MockServerRequest.builder().header("Content-Type", json.toString()).build(); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com") + .header(HttpHeaders.CONTENT_TYPE, json.toString()) + .build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); - request = MockServerRequest.builder().build(); + mockRequest = MockServerHttpRequest.get("https://example.com") + .header(HttpHeaders.CONTENT_TYPE, "foo/bar") + .build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isFalse(); } @@ -168,10 +219,16 @@ public class RequestPredicatesTests { public void accept() { MediaType json = MediaType.APPLICATION_JSON; RequestPredicate predicate = RequestPredicates.accept(json); - MockServerRequest request = MockServerRequest.builder().header("Accept", json.toString()).build(); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com") + .header(HttpHeaders.ACCEPT, json.toString()) + .build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); - request = MockServerRequest.builder().header("Accept", MediaType.TEXT_XML_VALUE).build(); + mockRequest = MockServerHttpRequest.get("https://example.com") + .header(HttpHeaders.ACCEPT, "foo/bar") + .build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isFalse(); } @@ -179,25 +236,30 @@ public class RequestPredicatesTests { public void pathExtension() { RequestPredicate predicate = RequestPredicates.pathExtension("txt"); - URI uri = URI.create("http://localhost/file.txt"); - MockServerRequest request = MockServerRequest.builder().uri(uri).build(); + URI uri = URI.create("https://localhost/file.txt"); + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); - uri = URI.create("http://localhost/FILE.TXT"); - request = MockServerRequest.builder().uri(uri).build(); + uri = URI.create("https://localhost/FILE.TXT"); + mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); predicate = RequestPredicates.pathExtension("bar"); assertThat(predicate.test(request)).isFalse(); - uri = URI.create("http://localhost/file.foo"); - request = MockServerRequest.builder().uri(uri).build(); + uri = URI.create("https://localhost/file.foo"); + mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isFalse(); } @Test public void queryParam() { - MockServerRequest request = MockServerRequest.builder().queryParam("foo", "bar").build(); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com") + .queryParam("foo", "bar").build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); RequestPredicate predicate = RequestPredicates.queryParam("foo", s -> s.equals("bar")); assertThat(predicate.test(request)).isTrue(); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java index 171aa4fabc4..667c211932e 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,6 +53,7 @@ import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; +import org.springframework.web.cors.CorsUtils; import org.springframework.web.util.UriBuilder; import org.springframework.web.util.UriUtils; import org.springframework.web.util.pattern.PathPattern; @@ -444,11 +445,24 @@ public abstract class RequestPredicates { @Override public boolean test(ServerRequest request) { - boolean match = this.httpMethods.contains(request.method()); - traceMatch("Method", this.httpMethods, request.method(), match); + HttpMethod method = method(request); + boolean match = this.httpMethods.contains(method); + traceMatch("Method", this.httpMethods, method, match); return match; } + @Nullable + private static HttpMethod method(ServerRequest request) { + if (CorsUtils.isPreFlightRequest(request.servletRequest())) { + String accessControlRequestMethod = + request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); + return HttpMethod.resolve(accessControlRequestMethod); + } + else { + return request.method(); + } + } + @Override public void accept(Visitor visitor) { visitor.method(Collections.unmodifiableSet(this.httpMethods)); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java index 6af031ba393..ab079610244 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import java.util.function.Function; import org.junit.jupiter.api.Test; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; @@ -57,6 +58,22 @@ public class RequestPredicatesTests { assertThat(predicate.test(request)).isFalse(); } + @Test + public void methodCorsPreFlight() { + RequestPredicate predicate = RequestPredicates.method(HttpMethod.PUT); + + MockHttpServletRequest servletRequest = new MockHttpServletRequest("OPTIONS", "https://example.com"); + servletRequest.addHeader("Origin", "https://example.com"); + servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "PUT"); + ServerRequest request = new DefaultServerRequest(servletRequest, emptyList()); + assertThat(predicate.test(request)).isTrue(); + + servletRequest.removeHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); + servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST"); + request = new DefaultServerRequest(servletRequest, emptyList()); + assertThat(predicate.test(request)).isFalse(); + } + @Test public void methods() { RequestPredicate predicate = RequestPredicates.methods(HttpMethod.GET, HttpMethod.HEAD);