diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java index 671780e6213..6592b6d4444 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java @@ -23,10 +23,13 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Set; +import org.springframework.http.HttpHeaders; import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.MediaType; +import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.cors.reactive.CorsUtils; import org.springframework.web.server.ServerWebExchange; @@ -50,6 +53,8 @@ public final class ConsumesRequestCondition extends AbstractRequestCondition expressions; + private boolean bodyRequired = true; + /** * Creates a new instance from 0 or more "consumes" expressions. @@ -141,6 +146,29 @@ public final class ConsumesRequestCondition extends AbstractRequestConditionBy default this is set to {@code true} in which case it is assumed a + * request body is required and this condition matches to the "Content-Type" + * header or falls back on "Content-Type: application/octet-stream". + *

If set to {@code false}, and the request does not have a body, then this + * condition matches automatically, i.e. without checking expressions. + * @param bodyRequired whether requests are expected to have a body + * @since 5.2 + */ + public void setBodyRequired(boolean bodyRequired) { + this.bodyRequired = bodyRequired; + } + + /** + * Return the setting for {@link #setBodyRequired(boolean)}. + * @since 5.2 + */ + public boolean isBodyRequired() { + return this.bodyRequired; + } + + /** * Returns the "other" instance if it has any expressions; returns "this" * instance otherwise. Practically that means a method-level "consumes" @@ -163,16 +191,27 @@ public final class ConsumesRequestCondition extends AbstractRequestCondition result = getMatchingExpressions(exchange); return !CollectionUtils.isEmpty(result) ? new ConsumesRequestCondition(result) : null; } + private boolean hasBody(ServerHttpRequest request) { + String contentLength = request.getHeaders().getFirst(HttpHeaders.CONTENT_LENGTH); + String transferEncoding = request.getHeaders().getFirst(HttpHeaders.TRANSFER_ENCODING); + return StringUtils.hasText(transferEncoding) || + (StringUtils.hasText(contentLength) && !contentLength.trim().equals("0")); + } + @Nullable private List getMatchingExpressions(ServerWebExchange exchange) { List result = null; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java index 5035390f99c..7f5051f633e 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java @@ -18,6 +18,7 @@ package org.springframework.web.reactive.result.method.annotation; import java.lang.reflect.AnnotatedElement; import java.lang.reflect.Method; +import java.lang.reflect.Parameter; import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; @@ -25,6 +26,8 @@ import java.util.function.Predicate; import org.springframework.context.EmbeddedValueResolverAware; import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.core.annotation.MergedAnnotations; import org.springframework.lang.Nullable; import org.springframework.stereotype.Controller; import org.springframework.util.Assert; @@ -32,12 +35,14 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.util.StringValueResolver; import org.springframework.web.bind.annotation.CrossOrigin; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.method.HandlerMethod; import org.springframework.web.reactive.accept.RequestedContentTypeResolver; import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder; +import org.springframework.web.reactive.result.condition.ConsumesRequestCondition; import org.springframework.web.reactive.result.condition.RequestCondition; import org.springframework.web.reactive.result.method.RequestMappingInfo; import org.springframework.web.reactive.result.method.RequestMappingInfoHandlerMapping; @@ -255,6 +260,31 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi } } + @Override + public void registerMapping(RequestMappingInfo mapping, Object handler, Method method) { + super.registerMapping(mapping, handler, method); + updateConsumesCondition(mapping, method); + } + + @Override + protected void registerHandlerMethod(Object handler, Method method, RequestMappingInfo mapping) { + super.registerHandlerMethod(handler, method, mapping); + updateConsumesCondition(mapping, method); + } + + private void updateConsumesCondition(RequestMappingInfo info, Method method) { + ConsumesRequestCondition condition = info.getConsumesCondition(); + if (!condition.isEmpty()) { + for (Parameter parameter : method.getParameters()) { + MergedAnnotation annot = MergedAnnotations.from(parameter).get(RequestBody.class); + if (annot.isPresent()) { + condition.setBodyRequired(annot.getBoolean("required")); + break; + } + } + } + } + @Override protected CorsConfiguration initCorsConfiguration(Object handler, Method method, RequestMappingInfo mappingInfo) { HandlerMethod handlerMethod = createHandlerMethod(handler, method); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/condition/ConsumesRequestConditionTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/condition/ConsumesRequestConditionTests.java index 3b1087fe5a3..ff6264646b3 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/condition/ConsumesRequestConditionTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/condition/ConsumesRequestConditionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,6 +99,24 @@ public class ConsumesRequestConditionTests { assertNull(condition.getMatchingCondition(exchange)); } + @Test // gh-22010 + public void consumesNoContent() { + ConsumesRequestCondition condition = new ConsumesRequestCondition("text/plain"); + condition.setBodyRequired(false); + + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + assertNotNull(condition.getMatchingCondition(MockServerWebExchange.from(request))); + + request = MockServerHttpRequest.get("/").header(HttpHeaders.CONTENT_LENGTH, "0").build(); + assertNotNull(condition.getMatchingCondition(MockServerWebExchange.from(request))); + + request = MockServerHttpRequest.get("/").header(HttpHeaders.CONTENT_LENGTH, "21").build(); + assertNull(condition.getMatchingCondition(MockServerWebExchange.from(request))); + + request = MockServerHttpRequest.get("/").header(HttpHeaders.TRANSFER_ENCODING, "chunked").build(); + assertNull(condition.getMatchingCondition(MockServerWebExchange.from(request))); + } + @Test public void compareToSingle() throws Exception { MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMappingTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMappingTests.java index 921b5a3772a..e31775eb6da 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMappingTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMappingTests.java @@ -22,7 +22,6 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.lang.reflect.Method; import java.security.Principal; -import java.util.ArrayList; import java.util.Collections; import java.util.Set; @@ -32,16 +31,20 @@ import org.junit.Test; import org.springframework.core.annotation.AliasFor; import org.springframework.http.MediaType; import org.springframework.stereotype.Controller; +import org.springframework.util.ClassUtils; import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PatchMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PutMapping; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.method.HandlerTypePredicate; +import org.springframework.web.reactive.result.condition.ConsumesRequestCondition; +import org.springframework.web.reactive.result.condition.PatternsRequestCondition; import org.springframework.web.reactive.result.method.RequestMappingInfo; import org.springframework.web.util.pattern.PathPattern; import org.springframework.web.util.pattern.PathPatternParser; @@ -103,10 +106,26 @@ public class RequestMappingHandlerMappingTests { @Test // SPR-14988 public void getMappingOverridesConsumesFromTypeLevelAnnotation() throws Exception { - RequestMappingInfo requestMappingInfo = assertComposedAnnotationMapping(RequestMethod.GET); + RequestMappingInfo requestMappingInfo = assertComposedAnnotationMapping(RequestMethod.POST); - assertArrayEquals(new MediaType[]{MediaType.ALL}, new ArrayList<>( - requestMappingInfo.getConsumesCondition().getConsumableMediaTypes()).toArray()); + ConsumesRequestCondition condition = requestMappingInfo.getConsumesCondition(); + assertEquals(Collections.singleton(MediaType.APPLICATION_XML), condition.getConsumableMediaTypes()); + } + + @Test // gh-22010 + public void consumesWithOptionalRequestBody() { + this.wac.registerSingleton("testController", ComposedAnnotationController.class); + this.wac.refresh(); + this.handlerMapping.afterPropertiesSet(); + RequestMappingInfo info = this.handlerMapping.getHandlerMethods().keySet().stream() + .filter(i -> { + PatternsRequestCondition condition = i.getPatternsCondition(); + return condition.getPatterns().iterator().next().getPatternString().equals("/post"); + }) + .findFirst() + .orElseThrow(() -> new AssertionError("No /post")); + + assertFalse(info.getConsumesCondition().isBodyRequired()); } @Test @@ -146,7 +165,7 @@ public class RequestMappingHandlerMappingTests { RequestMethod requestMethod) throws Exception { Class clazz = ComposedAnnotationController.class; - Method method = clazz.getMethod(methodName); + Method method = ClassUtils.getMethod(clazz, methodName, null); RequestMappingInfo info = this.handlerMapping.getMappingForMethod(method, clazz); assertNotNull(info); @@ -175,12 +194,12 @@ public class RequestMappingHandlerMappingTests { public void postJson() { } - @GetMapping(value = "/get", consumes = MediaType.ALL_VALUE) + @GetMapping("/get") public void get() { } - @PostMapping("/post") - public void post() { + @PostMapping(path = "/post", consumes = MediaType.APPLICATION_XML_VALUE) + public void post(@RequestBody(required = false) Foo foo) { } @PutMapping("/put") @@ -196,6 +215,9 @@ public class RequestMappingHandlerMappingTests { } } + private static class Foo { + } + @RequestMapping(method = RequestMethod.POST, produces = MediaType.APPLICATION_JSON_VALUE, diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/condition/ConsumesRequestCondition.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/condition/ConsumesRequestCondition.java index 396e6bf0526..4a4163ea137 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/condition/ConsumesRequestCondition.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/condition/ConsumesRequestCondition.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Set; import javax.servlet.http.HttpServletRequest; +import org.springframework.http.HttpHeaders; import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.MediaType; import org.springframework.lang.Nullable; @@ -49,8 +50,11 @@ public final class ConsumesRequestCondition extends AbstractRequestCondition expressions; + private boolean bodyRequired = true; + /** * Creates a new instance from 0 or more "consumes" expressions. @@ -141,6 +145,29 @@ public final class ConsumesRequestCondition extends AbstractRequestConditionBy default this is set to {@code true} in which case it is assumed a + * request body is required and this condition matches to the "Content-Type" + * header or falls back on "Content-Type: application/octet-stream". + *

If set to {@code false}, and the request does not have a body, then this + * condition matches automatically, i.e. without checking expressions. + * @param bodyRequired whether requests are expected to have a body + * @since 5.2 + */ + public void setBodyRequired(boolean bodyRequired) { + this.bodyRequired = bodyRequired; + } + + /** + * Return the setting for {@link #setBodyRequired(boolean)}. + * @since 5.2 + */ + public boolean isBodyRequired() { + return this.bodyRequired; + } + + /** * Returns the "other" instance if it has any expressions; returns "this" * instance otherwise. Practically that means a method-level "consumes" @@ -170,14 +197,17 @@ public final class ConsumesRequestCondition extends AbstractRequestCondition getMatchingExpressions(MediaType contentType) { List result = null; diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMapping.java index 3e1a772a8ef..e6580b4d5f8 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMapping.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMapping.java @@ -18,6 +18,7 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.lang.reflect.AnnotatedElement; import java.lang.reflect.Method; +import java.lang.reflect.Parameter; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; @@ -28,6 +29,8 @@ import javax.servlet.http.HttpServletRequest; import org.springframework.context.EmbeddedValueResolverAware; import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.core.annotation.MergedAnnotations; import org.springframework.lang.Nullable; import org.springframework.stereotype.Controller; import org.springframework.util.Assert; @@ -35,6 +38,7 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringValueResolver; import org.springframework.web.accept.ContentNegotiationManager; import org.springframework.web.bind.annotation.CrossOrigin; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.cors.CorsConfiguration; @@ -43,6 +47,7 @@ import org.springframework.web.servlet.handler.MatchableHandlerMapping; import org.springframework.web.servlet.handler.RequestMatchResult; import org.springframework.web.servlet.mvc.condition.AbstractRequestCondition; import org.springframework.web.servlet.mvc.condition.CompositeRequestCondition; +import org.springframework.web.servlet.mvc.condition.ConsumesRequestCondition; import org.springframework.web.servlet.mvc.condition.RequestCondition; import org.springframework.web.servlet.mvc.method.RequestMappingInfo; import org.springframework.web.servlet.mvc.method.RequestMappingInfoHandlerMapping; @@ -332,6 +337,31 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi } } + @Override + public void registerMapping(RequestMappingInfo mapping, Object handler, Method method) { + super.registerMapping(mapping, handler, method); + updateConsumesCondition(mapping, method); + } + + @Override + protected void registerHandlerMethod(Object handler, Method method, RequestMappingInfo mapping) { + super.registerHandlerMethod(handler, method, mapping); + updateConsumesCondition(mapping, method); + } + + private void updateConsumesCondition(RequestMappingInfo info, Method method) { + ConsumesRequestCondition condition = info.getConsumesCondition(); + if (!condition.isEmpty()) { + for (Parameter parameter : method.getParameters()) { + MergedAnnotation annot = MergedAnnotations.from(parameter).get(RequestBody.class); + if (annot.isPresent()) { + condition.setBodyRequired(annot.getBoolean("required")); + break; + } + } + } + } + @Override public RequestMatchResult match(HttpServletRequest request, String pattern) { RequestMappingInfo info = RequestMappingInfo.paths(pattern).options(this.config).build(); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/condition/ConsumesRequestConditionTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/condition/ConsumesRequestConditionTests.java index fc128b4b09b..76b2fbebe67 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/condition/ConsumesRequestConditionTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/condition/ConsumesRequestConditionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import java.util.Collections; import org.junit.Test; +import org.springframework.http.HttpHeaders; import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.web.servlet.mvc.condition.ConsumesRequestCondition.ConsumeMediaTypeExpression; @@ -107,6 +108,27 @@ public class ConsumesRequestConditionTests { assertNull(condition.getMatchingCondition(request)); } + @Test // gh-22010 + public void consumesNoContent() { + ConsumesRequestCondition condition = new ConsumesRequestCondition("text/plain"); + condition.setBodyRequired(false); + + MockHttpServletRequest request = new MockHttpServletRequest(); + assertNotNull(condition.getMatchingCondition(request)); + + request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.CONTENT_LENGTH, "0"); + assertNotNull(condition.getMatchingCondition(request)); + + request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.CONTENT_LENGTH, "21"); + assertNull(condition.getMatchingCondition(request)); + + request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.TRANSFER_ENCODING, "chunked"); + assertNull(condition.getMatchingCondition(request)); + } + @Test public void compareToSingle() { MockHttpServletRequest request = new MockHttpServletRequest(); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMappingTests.java index d0f227b5333..8564c30cf67 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMappingTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMappingTests.java @@ -22,7 +22,6 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.lang.reflect.Method; import java.security.Principal; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -34,6 +33,7 @@ import org.junit.Test; import org.springframework.core.annotation.AliasFor; import org.springframework.http.MediaType; import org.springframework.stereotype.Controller; +import org.springframework.util.ClassUtils; import org.springframework.web.accept.ContentNegotiationManager; import org.springframework.web.accept.PathExtensionContentNegotiationStrategy; import org.springframework.web.bind.annotation.DeleteMapping; @@ -41,11 +41,14 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PatchMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PutMapping; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.support.StaticWebApplicationContext; +import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.HandlerTypePredicate; +import org.springframework.web.servlet.mvc.condition.ConsumesRequestCondition; import org.springframework.web.servlet.mvc.method.RequestMappingInfo; import static org.junit.Assert.*; @@ -165,10 +168,23 @@ public class RequestMappingHandlerMappingTests { @Test // SPR-14988 public void getMappingOverridesConsumesFromTypeLevelAnnotation() throws Exception { - RequestMappingInfo requestMappingInfo = assertComposedAnnotationMapping(RequestMethod.GET); + RequestMappingInfo requestMappingInfo = assertComposedAnnotationMapping(RequestMethod.POST); - assertArrayEquals(new MediaType[]{MediaType.ALL}, new ArrayList<>( - requestMappingInfo.getConsumesCondition().getConsumableMediaTypes()).toArray()); + ConsumesRequestCondition condition = requestMappingInfo.getConsumesCondition(); + assertEquals(Collections.singleton(MediaType.APPLICATION_XML), condition.getConsumableMediaTypes()); + } + + @Test // gh-22010 + public void consumesWithOptionalRequestBody() { + this.wac.registerSingleton("testController", ComposedAnnotationController.class); + this.wac.refresh(); + this.handlerMapping.afterPropertiesSet(); + RequestMappingInfo info = this.handlerMapping.getHandlerMethods().keySet().stream() + .filter(i -> i.getPatternsCondition().getPatterns().equals(Collections.singleton("/post"))) + .findFirst() + .orElseThrow(() -> new AssertionError("No /post")); + + assertFalse(info.getConsumesCondition().isBodyRequired()); } @Test @@ -207,7 +223,7 @@ public class RequestMappingHandlerMappingTests { RequestMethod requestMethod) throws Exception { Class clazz = ComposedAnnotationController.class; - Method method = clazz.getMethod(methodName); + Method method = ClassUtils.getMethod(clazz, methodName, null); RequestMappingInfo info = this.handlerMapping.getMappingForMethod(method, clazz); assertNotNull(info); @@ -236,12 +252,12 @@ public class RequestMappingHandlerMappingTests { public void postJson() { } - @GetMapping(path = "/get", consumes = MediaType.ALL_VALUE) + @GetMapping("/get") public void get() { } - @PostMapping("/post") - public void post() { + @PostMapping(path = "/post", consumes = MediaType.APPLICATION_XML_VALUE) + public void post(@RequestBody(required = false) Foo foo) { } @PutMapping("/put") @@ -281,4 +297,8 @@ public class RequestMappingHandlerMappingTests { } } + + private static class Foo { + } + }