Improve support for Mono<ResponseEntity<?>>

If the body class is not resolvable from the return type and there is
a body instance we now fall back on the class of the body instance.

Issue: SPR-14877
This commit is contained in:
Rossen Stoyanchev
2016-11-04 12:51:00 +02:00
parent 84d3808b3b
commit c430402872
6 changed files with 57 additions and 30 deletions
@@ -22,6 +22,7 @@ import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.function.Supplier;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
@@ -101,13 +102,14 @@ public abstract class AbstractHandlerResultHandler implements Ordered {
* Select the best media type for the current request through a content * Select the best media type for the current request through a content
* negotiation algorithm. * negotiation algorithm.
* @param exchange the current request * @param exchange the current request
* @param producibleTypes the media types that can be produced for the current request * @param producibleTypesSupplier the media types that can be produced for the current request
* @return the selected media type or {@code null} * @return the selected media type or {@code null}
*/ */
protected MediaType selectMediaType(ServerWebExchange exchange, List<MediaType> producibleTypes) { protected MediaType selectMediaType(ServerWebExchange exchange,
Supplier<List<MediaType>> producibleTypesSupplier) {
List<MediaType> acceptableTypes = getAcceptableTypes(exchange); List<MediaType> acceptableTypes = getAcceptableTypes(exchange);
producibleTypes = getProducibleTypes(exchange, producibleTypes); List<MediaType> producibleTypes = getProducibleTypes(exchange, producibleTypesSupplier);
Set<MediaType> compatibleMediaTypes = new LinkedHashSet<>(); Set<MediaType> compatibleMediaTypes = new LinkedHashSet<>();
for (MediaType acceptable : acceptableTypes) { for (MediaType acceptable : acceptableTypes) {
@@ -139,13 +141,15 @@ public abstract class AbstractHandlerResultHandler implements Ordered {
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private List<MediaType> getProducibleTypes(ServerWebExchange exchange, List<MediaType> mediaTypes) { private List<MediaType> getProducibleTypes(ServerWebExchange exchange,
Supplier<List<MediaType>> producibleTypesSupplier) {
Optional<Object> optional = exchange.getAttribute(HandlerMapping.PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE); Optional<Object> optional = exchange.getAttribute(HandlerMapping.PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE);
if (optional.isPresent()) { if (optional.isPresent()) {
Set<MediaType> set = (Set<MediaType>) optional.get(); Set<MediaType> set = (Set<MediaType>) optional.get();
return new ArrayList<>(set); return new ArrayList<>(set);
} }
return mediaTypes; return producibleTypesSupplier.get();
} }
private MediaType selectMoreSpecificMediaType(MediaType acceptable, MediaType producible) { private MediaType selectMoreSpecificMediaType(MediaType acceptable, MediaType producible) {
@@ -93,8 +93,9 @@ public abstract class AbstractMessageWriterResultHandler extends AbstractHandler
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
protected Mono<Void> writeBody(Object body, MethodParameter bodyParameter, ServerWebExchange exchange) { protected Mono<Void> writeBody(Object body, MethodParameter bodyParameter, ServerWebExchange exchange) {
ResolvableType bodyType = ResolvableType.forMethodParameter(bodyParameter); ResolvableType valueType = ResolvableType.forMethodParameter(bodyParameter);
ReactiveAdapter adapter = getAdapterRegistry().getAdapterFrom(bodyType.resolve(), body); Class<?> valueClass = valueType.resolve();
ReactiveAdapter adapter = getAdapterRegistry().getAdapterFrom(valueClass, body);
Publisher<?> publisher; Publisher<?> publisher;
ResolvableType elementType; ResolvableType elementType;
@@ -102,11 +103,11 @@ public abstract class AbstractMessageWriterResultHandler extends AbstractHandler
publisher = adapter.toPublisher(body); publisher = adapter.toPublisher(body);
elementType = adapter.getDescriptor().isNoValue() ? elementType = adapter.getDescriptor().isNoValue() ?
ResolvableType.forClass(Void.class) : ResolvableType.forClass(Void.class) :
bodyType.getGeneric(0); valueType.getGeneric(0);
} }
else { else {
publisher = Mono.justOrEmpty(body); publisher = Mono.justOrEmpty(body);
elementType = bodyType; elementType = (valueClass == null && body != null ? ResolvableType.forInstance(body) : valueType);
} }
if (void.class == elementType.getRawClass() || Void.class == elementType.getRawClass()) { if (void.class == elementType.getRawClass() || Void.class == elementType.getRawClass()) {
@@ -114,29 +115,29 @@ public abstract class AbstractMessageWriterResultHandler extends AbstractHandler
.doOnSubscribe(sub -> updateResponseStatus(bodyParameter, exchange)); .doOnSubscribe(sub -> updateResponseStatus(bodyParameter, exchange));
} }
List<MediaType> producibleTypes = getProducibleMediaTypes(elementType);
if (producibleTypes.isEmpty()) {
return Mono.error(new IllegalStateException(
"No converter for return value type: " + elementType));
}
ServerHttpRequest request = exchange.getRequest(); ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
MediaType bestMediaType = selectMediaType(exchange, producibleTypes); MediaType bestMediaType = selectMediaType(exchange, () -> getProducibleMediaTypes(elementType));
if (bestMediaType != null) { if (bestMediaType != null) {
for (HttpMessageWriter<?> messageWriter : getMessageWriters()) { for (HttpMessageWriter<?> messageWriter : getMessageWriters()) {
if (messageWriter.canWrite(elementType, bestMediaType)) { if (messageWriter.canWrite(elementType, bestMediaType)) {
Mono<Void> bodyWriter = (messageWriter instanceof ServerHttpMessageWriter ? Mono<Void> bodyWriter = (messageWriter instanceof ServerHttpMessageWriter ?
((ServerHttpMessageWriter<?>) messageWriter).write((Publisher) publisher, ((ServerHttpMessageWriter<?>) messageWriter).write((Publisher) publisher,
bodyType, elementType, bestMediaType, request, response, Collections.emptyMap()) : valueType, elementType, bestMediaType, request, response, Collections.emptyMap()) :
messageWriter.write((Publisher) publisher, elementType, messageWriter.write((Publisher) publisher, elementType,
bestMediaType, response, Collections.emptyMap())); bestMediaType, response, Collections.emptyMap()));
return bodyWriter.doOnSubscribe(sub -> updateResponseStatus(bodyParameter, exchange)); return bodyWriter.doOnSubscribe(sub -> updateResponseStatus(bodyParameter, exchange));
} }
} }
} }
else {
if (getProducibleMediaTypes(elementType).isEmpty()) {
return Mono.error(new IllegalStateException(
"No converter for return value type: " + elementType));
}
}
return Mono.error(new NotAcceptableStatusException(producibleTypes)); return Mono.error(new NotAcceptableStatusException(getProducibleMediaTypes(elementType)));
} }
private List<MediaType> getProducibleMediaTypes(ResolvableType elementType) { private List<MediaType> getProducibleMediaTypes(ResolvableType elementType) {
@@ -319,7 +319,7 @@ public class ViewResolutionResultHandler extends AbstractHandlerResultHandler
views.addAll(getDefaultViews()); views.addAll(getDefaultViews());
List<MediaType> producibleTypes = getProducibleMediaTypes(views); List<MediaType> producibleTypes = getProducibleMediaTypes(views);
MediaType bestMediaType = selectMediaType(exchange, producibleTypes); MediaType bestMediaType = selectMediaType(exchange, () -> producibleTypes);
if (bestMediaType != null) { if (bestMediaType != null) {
for (View view : views) { for (View view : views) {
@@ -71,7 +71,7 @@ public class HandlerResultHandlerTests {
public void usesContentTypeResolver() throws Exception { public void usesContentTypeResolver() throws Exception {
TestResultHandler resultHandler = new TestResultHandler(new FixedContentTypeResolver(IMAGE_GIF)); TestResultHandler resultHandler = new TestResultHandler(new FixedContentTypeResolver(IMAGE_GIF));
List<MediaType> mediaTypes = Arrays.asList(IMAGE_JPEG, IMAGE_GIF, IMAGE_PNG); List<MediaType> mediaTypes = Arrays.asList(IMAGE_JPEG, IMAGE_GIF, IMAGE_PNG);
MediaType actual = resultHandler.selectMediaType(this.exchange, mediaTypes); MediaType actual = resultHandler.selectMediaType(this.exchange, () -> mediaTypes);
assertEquals(IMAGE_GIF, actual); assertEquals(IMAGE_GIF, actual);
} }
@@ -82,7 +82,7 @@ public class HandlerResultHandlerTests {
this.exchange.getAttributes().put(PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE, producible); this.exchange.getAttributes().put(PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE, producible);
List<MediaType> mediaTypes = Arrays.asList(IMAGE_JPEG, IMAGE_GIF, IMAGE_PNG); List<MediaType> mediaTypes = Arrays.asList(IMAGE_JPEG, IMAGE_GIF, IMAGE_PNG);
MediaType actual = resultHandler.selectMediaType(this.exchange, mediaTypes); MediaType actual = resultHandler.selectMediaType(this.exchange, () -> mediaTypes);
assertEquals(IMAGE_GIF, actual); assertEquals(IMAGE_GIF, actual);
} }
@@ -92,7 +92,7 @@ public class HandlerResultHandlerTests {
this.request.setHeader("Accept", "text/plain; q=0.5, application/json"); this.request.setHeader("Accept", "text/plain; q=0.5, application/json");
List<MediaType> mediaTypes = Arrays.asList(TEXT_PLAIN, APPLICATION_JSON_UTF8); List<MediaType> mediaTypes = Arrays.asList(TEXT_PLAIN, APPLICATION_JSON_UTF8);
MediaType actual = this.resultHandler.selectMediaType(this.exchange, mediaTypes); MediaType actual = this.resultHandler.selectMediaType(this.exchange, () -> mediaTypes);
assertEquals(APPLICATION_JSON_UTF8, actual); assertEquals(APPLICATION_JSON_UTF8, actual);
} }
@@ -102,7 +102,8 @@ public class HandlerResultHandlerTests {
MediaType text8859 = MediaType.parseMediaType("text/plain;charset=ISO-8859-1"); MediaType text8859 = MediaType.parseMediaType("text/plain;charset=ISO-8859-1");
MediaType textUtf8 = MediaType.parseMediaType("text/plain;charset=UTF-8"); MediaType textUtf8 = MediaType.parseMediaType("text/plain;charset=UTF-8");
this.request.getHeaders().setAccept(Collections.singletonList(text8859)); this.request.getHeaders().setAccept(Collections.singletonList(text8859));
MediaType actual = this.resultHandler.selectMediaType(this.exchange, Collections.singletonList(textUtf8)); MediaType actual = this.resultHandler.selectMediaType(this.exchange,
() -> Collections.singletonList(textUtf8));
assertEquals(text8859, actual); assertEquals(text8859, actual);
} }
@@ -110,7 +111,7 @@ public class HandlerResultHandlerTests {
@Test // SPR-12894 @Test // SPR-12894
public void noConcreteMediaType() throws Exception { public void noConcreteMediaType() throws Exception {
List<MediaType> producible = Collections.singletonList(ALL); List<MediaType> producible = Collections.singletonList(ALL);
MediaType actual = this.resultHandler.selectMediaType(this.exchange, producible); MediaType actual = this.resultHandler.selectMediaType(this.exchange, () -> producible);
assertEquals(APPLICATION_OCTET_STREAM, actual); assertEquals(APPLICATION_OCTET_STREAM, actual);
} }
@@ -43,7 +43,6 @@ import org.springframework.core.codec.ByteBufferEncoder;
import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.support.DataBufferTestUtils; import org.springframework.core.io.buffer.support.DataBufferTestUtils;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.codec.EncoderHttpMessageWriter; import org.springframework.http.codec.EncoderHttpMessageWriter;
@@ -51,9 +50,9 @@ import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.http.codec.ResourceHttpMessageWriter; import org.springframework.http.codec.ResourceHttpMessageWriter;
import org.springframework.http.codec.json.Jackson2JsonEncoder; import org.springframework.http.codec.json.Jackson2JsonEncoder;
import org.springframework.http.codec.xml.Jaxb2XmlEncoder; import org.springframework.http.codec.xml.Jaxb2XmlEncoder;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
import org.springframework.web.reactive.accept.RequestedContentTypeResolver; import org.springframework.web.reactive.accept.RequestedContentTypeResolver;
import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder; import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder;
@@ -211,7 +210,7 @@ public class MessageWriterResultHandlerTests {
public ParentClass() { public ParentClass() {
} }
public ParentClass(String parentProperty) { ParentClass(String parentProperty) {
this.parentProperty = parentProperty; this.parentProperty = parentProperty;
} }
@@ -235,7 +234,7 @@ public class MessageWriterResultHandlerTests {
@JsonTypeName("bar") @JsonTypeName("bar")
private static class Bar extends ParentClass { private static class Bar extends ParentClass {
public Bar(String parentProperty) { Bar(String parentProperty) {
super(parentProperty); super(parentProperty);
} }
} }
@@ -253,7 +252,7 @@ public class MessageWriterResultHandlerTests {
private String name; private String name;
public SimpleBean(Long id, String name) { SimpleBean(Long id, String name) {
this.id = id; this.id = id;
this.name = name; this.name = name;
} }
@@ -23,6 +23,7 @@ import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@@ -37,11 +38,11 @@ import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.core.codec.ByteBufferEncoder; import org.springframework.core.codec.ByteBufferEncoder;
import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.support.DataBufferTestUtils; import org.springframework.core.io.buffer.support.DataBufferTestUtils;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.http.codec.EncoderHttpMessageWriter; import org.springframework.http.codec.EncoderHttpMessageWriter;
import org.springframework.http.codec.HttpMessageWriter; import org.springframework.http.codec.HttpMessageWriter;
@@ -51,6 +52,7 @@ import org.springframework.http.codec.xml.Jaxb2XmlEncoder;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.accept.RequestedContentTypeResolver; import org.springframework.web.reactive.accept.RequestedContentTypeResolver;
import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder; import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder;
@@ -270,6 +272,23 @@ public class ResponseEntityResultHandlerTests {
assertConditionalResponse(HttpStatus.OK, "body", newEtag, oneMinAgo); assertConditionalResponse(HttpStatus.OK, "body", newEtag, oneMinAgo);
} }
@Test // SPR-14877
public void handleMonoWithWildcardBodyType() throws Exception {
this.exchange.getAttributes().put(HandlerMapping.PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE,
Collections.singleton(MediaType.APPLICATION_JSON));
HandlerResult result = new HandlerResult(new TestController(), Mono.just(ok().body("body")),
ResolvableMethod.onClass(TestController.class)
.name("monoResponseEntityWildcard")
.resolveReturnType());
this.resultHandler.handleResult(this.exchange, result).block(Duration.ofSeconds(5));
assertEquals(HttpStatus.OK, this.response.getStatusCode());
assertResponseBody("\"body\"");
}
private void testHandle(Object returnValue, ResolvableType type) { private void testHandle(Object returnValue, ResolvableType type) {
HandlerResult result = handlerResult(returnValue, type); HandlerResult result = handlerResult(returnValue, type);
@@ -333,6 +352,9 @@ public class ResponseEntityResultHandlerTests {
String string() { return null; } String string() { return null; }
Completable completable() { return null; } Completable completable() { return null; }
Mono<ResponseEntity<?>> monoResponseEntityWildcard() { return null; }
} }
} }