diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/AbstractHandlerMethodMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/AbstractHandlerMethodMapping.java index 585bd125716..500887f979e 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/AbstractHandlerMethodMapping.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/AbstractHandlerMethodMapping.java @@ -262,6 +262,7 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap try { // Ensure form data is parsed for "params" conditions... return exchange.getRequestParams() + .then(exchange.getMultipartData()) .then(Mono.defer(() -> { HandlerMethod handlerMethod = null; try { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolver.java index 8a7619925b6..840fd3d6efd 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolver.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolver.java @@ -133,6 +133,7 @@ class ControllerMethodResolver { // Annotation-based... registrar.add(new RequestParamMethodArgumentResolver(beanFactory, reactiveRegistry, false)); + registrar.add(new RequestPartMethodArgumentResolver(beanFactory, reactiveRegistry, false)); registrar.add(new RequestParamMapMethodArgumentResolver(reactiveRegistry)); registrar.add(new PathVariableMethodArgumentResolver(beanFactory, reactiveRegistry)); registrar.add(new PathVariableMapMethodArgumentResolver(reactiveRegistry)); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestParamMapMethodArgumentResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestParamMapMethodArgumentResolver.java index 7635a55f3cb..b279473dcf3 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestParamMapMethodArgumentResolver.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestParamMapMethodArgumentResolver.java @@ -21,6 +21,8 @@ import java.util.Optional; import org.springframework.core.MethodParameter; import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.ResolvableType; +import org.springframework.http.codec.multipart.Part; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; @@ -42,6 +44,7 @@ import org.springframework.web.server.ServerWebExchange; * request parameters have multiple values. * * @author Rossen Stoyanchev + * @author Sebastien Deleuze * @since 5.0 * @see RequestParamMethodArgumentResolver */ @@ -67,12 +70,17 @@ public class RequestParamMapMethodArgumentResolver extends HandlerMethodArgument public Optional resolveArgumentValue(MethodParameter methodParameter, BindingContext context, ServerWebExchange exchange) { - Class paramType = methodParameter.getParameterType(); - boolean isMultiValueMap = MultiValueMap.class.isAssignableFrom(paramType); + ResolvableType paramType = ResolvableType.forType(methodParameter.getGenericParameterType()); + boolean isMultiValueMap = MultiValueMap.class.isAssignableFrom(paramType.getRawClass()); + + if (paramType.getGeneric(1).getRawClass() == Part.class) { + MultiValueMap requestParts = exchange.getMultipartData().subscribe().peek(); + Assert.notNull(requestParts, "Expected multipart data (if any) to be parsed."); + return Optional.of(isMultiValueMap ? requestParts : requestParts.toSingleValueMap()); + } MultiValueMap requestParams = exchange.getRequestParams().subscribe().peek(); Assert.notNull(requestParams, "Expected form data (if any) to be parsed."); - return Optional.of(isMultiValueMap ? requestParams : requestParams.toSingleValueMap()); } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestParamMethodArgumentResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestParamMethodArgumentResolver.java index 703618cf936..816c1036b53 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestParamMethodArgumentResolver.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestParamMethodArgumentResolver.java @@ -25,6 +25,7 @@ import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.core.MethodParameter; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.convert.converter.Converter; +import org.springframework.http.codec.multipart.Part; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; @@ -102,7 +103,7 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueSyncAr protected Optional resolveNamedValue(String name, MethodParameter parameter, ServerWebExchange exchange) { - List paramValues = getRequestParams(exchange).get(name); + List paramValues = parameter.getParameter().getType() == Part.class ? getMultipartData(exchange).get(name) : getRequestParams(exchange).get(name); Object result = null; if (paramValues != null) { result = (paramValues.size() == 1 ? paramValues.get(0) : paramValues); @@ -116,6 +117,12 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueSyncAr return params; } + private MultiValueMap getMultipartData(ServerWebExchange exchange) { + MultiValueMap params = exchange.getMultipartData().subscribe().peek(); + Assert.notNull(params, "Expected multipart data (if any) to be parsed."); + return params; + } + @Override protected void handleMissingValue(String name, MethodParameter parameter, ServerWebExchange exchange) { String type = parameter.getNestedParameterType().getSimpleName(); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolver.java new file mode 100644 index 00000000000..7fc9acc4e8e --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolver.java @@ -0,0 +1,128 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.result.method.annotation; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.core.MethodParameter; +import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.codec.multipart.Part; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RequestPart; +import org.springframework.web.bind.annotation.ValueConstants; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.ServerWebInputException; + +/** + * Resolver for method arguments annotated with @{@link RequestPart}. + * + * @author Sebastien Deleuze + * @since 5.0 + * @see RequestParamMapMethodArgumentResolver + */ +public class RequestPartMethodArgumentResolver extends AbstractNamedValueSyncArgumentResolver { + + private final boolean useDefaultResolution; + + + /** + * Class constructor with a default resolution mode flag. + * @param factory a bean factory used for resolving ${...} placeholder + * and #{...} SpEL expressions in default values, or {@code null} if default + * values are not expected to contain expressions + * @param registry for checking reactive type wrappers + * @param useDefaultResolution in default resolution mode a method argument + * that is a simple type, as defined in {@link BeanUtils#isSimpleProperty}, + * is treated as a request parameter even if it isn't annotated, the + * request parameter name is derived from the method parameter name. + */ + public RequestPartMethodArgumentResolver( + ConfigurableBeanFactory factory, ReactiveAdapterRegistry registry, boolean useDefaultResolution) { + + super(factory, registry); + this.useDefaultResolution = useDefaultResolution; + } + + + @Override + public boolean supportsParameter(MethodParameter param) { + if (checkAnnotatedParamNoReactiveWrapper(param, RequestPart.class, this::singleParam)) { + return true; + } + else if (this.useDefaultResolution) { + return checkParameterTypeNoReactiveWrapper(param, BeanUtils::isSimpleProperty) || + BeanUtils.isSimpleProperty(param.nestedIfOptional().getNestedParameterType()); + } + return false; + } + + private boolean singleParam(RequestPart requestParam, Class type) { + return !Map.class.isAssignableFrom(type) || StringUtils.hasText(requestParam.name()); + } + + @Override + protected NamedValueInfo createNamedValueInfo(MethodParameter parameter) { + RequestPart ann = parameter.getParameterAnnotation(RequestPart.class); + return (ann != null ? new RequestPartNamedValueInfo(ann) : new RequestPartNamedValueInfo()); + } + + @Override + protected Optional resolveNamedValue(String name, MethodParameter parameter, + ServerWebExchange exchange) { + + List paramValues = getMultipartData(exchange).get(name); + Object result = null; + if (paramValues != null) { + result = (paramValues.size() == 1 ? paramValues.get(0) : paramValues); + } + return Optional.ofNullable(result); + } + + private MultiValueMap getMultipartData(ServerWebExchange exchange) { + MultiValueMap params = exchange.getMultipartData().subscribe().peek(); + Assert.notNull(params, "Expected multipart data (if any) to be parsed."); + return params; + } + + @Override + protected void handleMissingValue(String name, MethodParameter parameter, ServerWebExchange exchange) { + String type = parameter.getNestedParameterType().getSimpleName(); + String reason = "Required " + type + " parameter '" + name + "' is not present"; + throw new ServerWebInputException(reason, parameter); + } + + + private static class RequestPartNamedValueInfo extends NamedValueInfo { + + RequestPartNamedValueInfo() { + super("", false, ValueConstants.DEFAULT_NONE); + } + + RequestPartNamedValueInfo(RequestPart annotation) { + super(annotation.name(), annotation.required(), ValueConstants.DEFAULT_NONE); + } + } + +} diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolverTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolverTests.java index 59dda9bc248..801dcce5fb3 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolverTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolverTests.java @@ -91,6 +91,7 @@ public class ControllerMethodResolverTests { AtomicInteger index = new AtomicInteger(-1); assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass()); + assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass()); @@ -129,6 +130,7 @@ public class ControllerMethodResolverTests { AtomicInteger index = new AtomicInteger(-1); assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass()); + assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass()); @@ -165,6 +167,7 @@ public class ControllerMethodResolverTests { AtomicInteger index = new AtomicInteger(-1); assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass()); + assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass()); @@ -195,6 +198,7 @@ public class ControllerMethodResolverTests { AtomicInteger index = new AtomicInteger(-1); assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass()); + assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass()); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java new file mode 100644 index 00000000000..33a08473443 --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.result.method.annotation; + +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.reactive.AbstractHttpHandlerIntegrationTests; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RequestPart; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.reactive.DispatcherHandler; +import org.springframework.web.reactive.config.EnableWebFlux; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.server.adapter.WebHttpHandlerBuilder; + +import static org.junit.Assert.assertEquals; + +public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + private AnnotationConfigApplicationContext wac; + + private WebClient webClient; + + @Override + @Before + public void setup() throws Exception { + super.setup(); + this.webClient = WebClient.create("http://localhost:" + this.port); + } + + + @Override + protected HttpHandler createHttpHandler() { + this.wac = new AnnotationConfigApplicationContext(); + this.wac.register(TestConfiguration.class); + this.wac.refresh(); + + return WebHttpHandlerBuilder.webHandler(new DispatcherHandler(this.wac)).build(); + } + + @Test + public void map() { + test("/map"); + } + + @Test + public void multiValueMap() { + test("/multivaluemap"); + } + + @Test + public void partParam() { + test("/partparam"); + } + + @Test + public void part() { + test("/part"); + } + + private void test(String uri) { + Mono result = webClient + .post() + .uri(uri) + .contentType(MediaType.MULTIPART_FORM_DATA) + .body(BodyInserters.fromMultipartData(generateBody())) + .exchange(); + + StepVerifier + .create(result) + .consumeNextWith(response -> assertEquals(HttpStatus.OK, response.statusCode())) + .verifyComplete(); + } + + private MultiValueMap generateBody() { + HttpHeaders fooHeaders = new HttpHeaders(); + fooHeaders.setContentType(MediaType.TEXT_PLAIN); + ClassPathResource fooResource = new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"); + HttpEntity fooPart = new HttpEntity<>(fooResource, fooHeaders); + HttpEntity barPart = new HttpEntity<>("bar"); + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("fooPart", fooPart); + parts.add("barPart", barPart); + return parts; + } + + @RestController + @SuppressWarnings("unused") + static class MultipartController { + + @PostMapping("/map") + void map(@RequestParam Map parts) { + assertEquals(2, parts.size()); + assertEquals("foo.txt", parts.get("fooPart").getFilename().get()); + assertEquals("bar", parts.get("barPart").getContentAsString().block()); + } + + @PostMapping("/multivaluemap") + void multiValueMap(@RequestParam MultiValueMap parts) { + Map map = parts.toSingleValueMap(); + assertEquals(2, map.size()); + assertEquals("foo.txt", map.get("fooPart").getFilename().get()); + assertEquals("bar", map.get("barPart").getContentAsString().block()); + } + + @PostMapping("/partparam") + void partParam(@RequestParam Part fooPart) { + assertEquals("foo.txt", fooPart.getFilename().get()); + } + + @PostMapping("/part") + void part(@RequestPart Part fooPart) { + assertEquals("foo.txt", fooPart.getFilename().get()); + } + + } + + @Configuration + @EnableWebFlux + @SuppressWarnings("unused") + static class TestConfiguration { + + @Bean + public MultipartController multipartController() { + return new MultipartController(); + } + } + +}