diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandler.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandler.java index 42e50afe653..561bf0578dc 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandler.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandler.java @@ -79,20 +79,44 @@ public class ResponseEntityResultHandler extends AbstractMessageConverterResultH @Override public boolean supports(HandlerResult result) { ResolvableType returnType = result.getReturnValueType(); - return (HttpEntity.class.isAssignableFrom(returnType.getRawClass()) && - !RequestEntity.class.isAssignableFrom(returnType.getRawClass())); + if (isSupportedType(returnType)) { + return true; + } + else if (getConversionService().canConvert(returnType.getRawClass(), Mono.class)) { + ResolvableType genericType = result.getReturnValueType().getGeneric(0); + return isSupportedType(genericType); + + } + return false; + } + + private boolean isSupportedType(ResolvableType returnType) { + Class clazz = returnType.getRawClass(); + return (HttpEntity.class.isAssignableFrom(clazz) && !RequestEntity.class.isAssignableFrom(clazz)); } @Override public Mono handleResult(ServerWebExchange exchange, HandlerResult result) { - Object body = null; + ResolvableType returnType = result.getReturnValueType(); + Mono returnValueMono; + ResolvableType bodyType; Optional optional = result.getReturnValue(); - if (optional.isPresent()) { - Assert.isInstanceOf(HttpEntity.class, optional.get()); - HttpEntity httpEntity = (HttpEntity) optional.get(); + if (optional.isPresent() && getConversionService().canConvert(returnType.getRawClass(), Mono.class)) { + returnValueMono = getConversionService().convert(optional.get(), Mono.class); + bodyType = returnType.getGeneric(0).getGeneric(0); + } + else { + returnValueMono = Mono.justOrEmpty(optional); + bodyType = returnType.getGeneric(0); + } + + return returnValueMono.then(returnValue -> { + + Assert.isInstanceOf(HttpEntity.class, returnValue); + HttpEntity httpEntity = (HttpEntity) returnValue; if (httpEntity instanceof ResponseEntity) { ResponseEntity responseEntity = (ResponseEntity) httpEntity; @@ -108,11 +132,8 @@ public class ResponseEntityResultHandler extends AbstractMessageConverterResultH .forEach(entry -> responseHeaders.put(entry.getKey(), entry.getValue())); } - body = httpEntity.getBody(); - } - - ResolvableType bodyType = result.getReturnValueType().getGeneric(0); - return writeBody(exchange, body, bodyType); + return writeBody(exchange, httpEntity.getBody(), bodyType); + }); } } diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandlerTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandlerTests.java index 1b4a1b8ebc7..1bebf0fb2d0 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandlerTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandlerTests.java @@ -16,23 +16,28 @@ package org.springframework.web.reactive.result.method.annotation; import java.net.URI; +import java.nio.charset.Charset; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CompletableFuture; import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.core.test.TestSubscriber; +import rx.Single; import org.springframework.core.ResolvableType; import org.springframework.core.codec.support.ByteBufferEncoder; import org.springframework.core.codec.support.JacksonJsonEncoder; import org.springframework.core.codec.support.Jaxb2Encoder; import org.springframework.core.codec.support.StringEncoder; -import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.convert.support.GenericConversionService; import org.springframework.core.convert.support.ReactiveStreamsToCompletableFutureConverter; import org.springframework.core.convert.support.ReactiveStreamsToRxJava1Converter; +import org.springframework.core.io.buffer.support.DataBufferTestUtils; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; @@ -57,6 +62,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; +import static org.springframework.core.ResolvableType.forClassWithGenerics; /** * Unit tests for {@link ResponseEntityResultHandler}. When adding a test also @@ -104,22 +110,29 @@ public class ResponseEntityResultHandlerTests { service.addConverter(new ReactiveStreamsToRxJava1Converter()); RequestedContentTypeResolver resolver = new RequestedContentTypeResolverBuilder().build(); - return new ResponseEntityResultHandler(converterList, new DefaultConversionService(), resolver); + return new ResponseEntityResultHandler(converterList, service, resolver); } - @Test + @Test @SuppressWarnings("ConstantConditions") public void supports() throws NoSuchMethodException { ModelMap model = new ExtendedModelMap(); - ResolvableType type = ResolvableType.forClassWithGenerics(ResponseEntity.class, String.class); - assertTrue(this.resultHandler.supports(new HandlerResult(HANDLER, null, type, model))); + Object value = null; + ResolvableType type = responseEntityType(String.class); + assertTrue(this.resultHandler.supports(new HandlerResult(HANDLER, value, type, model))); + + type = forClassWithGenerics(Mono.class, responseEntityType(String.class)); + assertTrue(this.resultHandler.supports(new HandlerResult(HANDLER, value, type, model))); - type = ResolvableType.forClassWithGenerics(ResponseEntity.class, Void.class); - assertTrue(this.resultHandler.supports(new HandlerResult(HANDLER, null, type, model))); + type = forClassWithGenerics(Single.class, responseEntityType(String.class)); + assertTrue(this.resultHandler.supports(new HandlerResult(HANDLER, value, type, model))); - type = ResolvableType.forClass(Void.class); - assertFalse(this.resultHandler.supports(new HandlerResult(HANDLER, null, type, model))); + type = forClassWithGenerics(CompletableFuture.class, responseEntityType(String.class)); + assertTrue(this.resultHandler.supports(new HandlerResult(HANDLER, value, type, model))); + + type = ResolvableType.forClass(String.class); + assertFalse(this.resultHandler.supports(new HandlerResult(HANDLER, value, type, model))); } @Test @@ -129,7 +142,7 @@ public class ResponseEntityResultHandlerTests { @Test public void statusCode() throws Exception { - ResolvableType type = ResolvableType.forClassWithGenerics(ResponseEntity.class, Void.class); + ResolvableType type = responseEntityType(Void.class); HandlerResult result = new HandlerResult(HANDLER, ResponseEntity.noContent().build(), type); this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5)); @@ -141,9 +154,9 @@ public class ResponseEntityResultHandlerTests { @Test public void headers() throws Exception { URI location = new URI("/path"); - ResolvableType type = ResolvableType.forClassWithGenerics(ResponseEntity.class, Void.class); + ResolvableType type = responseEntityType(Void.class); HandlerResult result = new HandlerResult(HANDLER, ResponseEntity.created(location).build(), type); - this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5)); + this.resultHandler.handleResult(this.exchange, result).block(Duration.ofSeconds(5)); assertEquals(HttpStatus.CREATED, this.response.getStatus()); assertEquals(1, this.response.getHeaders().size()); @@ -151,4 +164,44 @@ public class ResponseEntityResultHandlerTests { assertNull(this.response.getBody()); } + @Test + public void handleReturnTypes() throws Exception { + Object returnValue = ResponseEntity.ok("abc"); + ResolvableType returnType = responseEntityType(String.class); + testHandle(returnValue, returnType); + + returnValue = Mono.just(ResponseEntity.ok("abc")); + returnType = forClassWithGenerics(Mono.class, responseEntityType(String.class)); + testHandle(returnValue, returnType); + + returnValue = Mono.just(ResponseEntity.ok("abc")); + returnType = forClassWithGenerics(Single.class, responseEntityType(String.class)); + testHandle(returnValue, returnType); + + returnValue = Mono.just(ResponseEntity.ok("abc")); + returnType = forClassWithGenerics(CompletableFuture.class, responseEntityType(String.class)); + testHandle(returnValue, returnType); + } + + + private void testHandle(Object returnValue, ResolvableType returnType) { + HandlerResult result = new HandlerResult(HANDLER, returnValue, returnType); + this.resultHandler.handleResult(this.exchange, result).block(Duration.ofSeconds(5)); + + assertEquals(HttpStatus.OK, this.response.getStatus()); + assertEquals("text/plain;charset=UTF-8", this.response.getHeaders().getFirst("Content-Type")); + assertResponseBody("abc"); + } + + + private ResolvableType responseEntityType(Class bodyType) { + return forClassWithGenerics(ResponseEntity.class, bodyType); + } + + private void assertResponseBody(String responseBody) { + TestSubscriber.subscribe(this.response.getBody()) + .assertValuesWith(buf -> assertEquals(responseBody, + DataBufferTestUtils.dumpString(buf, Charset.forName("UTF-8")))); + } + }