diff --git a/spring-core-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt b/spring-core-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt index 50d535abefb..b4e4cd8e0ab 100644 --- a/spring-core-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt +++ b/spring-core-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt @@ -19,8 +19,12 @@ package org.springframework.core import kotlinx.coroutines.Deferred import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.async +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactor.mono @@ -29,6 +33,8 @@ import reactor.core.publisher.onErrorMap import java.lang.reflect.InvocationTargetException import java.lang.reflect.Method import kotlin.reflect.full.callSuspend +import kotlin.reflect.full.isSubtypeOf +import kotlin.reflect.full.starProjectedType import kotlin.reflect.jvm.kotlinFunction /** @@ -50,18 +56,29 @@ internal fun monoToDeferred(source: Mono) = GlobalScope.async(Dispatchers.Unconfined) { source.awaitFirstOrNull() } /** - * Invoke an handler method converting suspending method to [Mono] if necessary. + * Invoke an handler method converting suspending method to [Mono] or [Flow] if necessary. * * @author Sebastien Deleuze * @since 5.2 */ +@Suppress("UNCHECKED_CAST") +@FlowPreview internal fun invokeHandlerMethod(method: Method, bean: Any, vararg args: Any?): Any? { val function = method.kotlinFunction!! return if (function.isSuspend) { - GlobalScope.mono(Dispatchers.Unconfined) { - function.callSuspend(bean, *args.sliceArray(0..(args.size-2))) - .let { if (it == Unit) null else it} } - .onErrorMap(InvocationTargetException::class) { it.targetException } + if (function.returnType.isSubtypeOf(Flow::class.starProjectedType)) { + flow { + (function.callSuspend(bean, *args.sliceArray(0..(args.size-2))) as Flow<*>).collect { + emit(it) + } + } + } + else { + GlobalScope.mono(Dispatchers.Unconfined) { + function.callSuspend(bean, *args.sliceArray(0..(args.size-2))) + .let { if (it == Unit) null else it} + }.onErrorMap(InvocationTargetException::class) { it.targetException } + } } else { function.call(bean, *args) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java index 23eb9422a17..5b843491c9b 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java @@ -16,12 +16,16 @@ package org.springframework.web.reactive.result.method.annotation; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; +import kotlin.reflect.KFunction; +import kotlin.reflect.jvm.ReflectJvmMapping; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; +import org.springframework.core.KotlinDetector; import org.springframework.core.MethodParameter; import org.springframework.core.ReactiveAdapter; import org.springframework.core.ReactiveAdapterRegistry; @@ -48,6 +52,8 @@ import org.springframework.web.server.ServerWebExchange; */ public abstract class AbstractMessageWriterResultHandler extends HandlerResultHandlerSupport { + private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow"; + private final List> messageWriters; @@ -110,7 +116,7 @@ public abstract class AbstractMessageWriterResultHandler extends HandlerResultHa * @return indicates completion or error * @since 5.0.2 */ - @SuppressWarnings({"unchecked", "rawtypes"}) + @SuppressWarnings({"unchecked", "rawtypes", "ConstantConditions"}) protected Mono writeBody(@Nullable Object body, MethodParameter bodyParameter, @Nullable MethodParameter actualParam, ServerWebExchange exchange) { @@ -122,7 +128,11 @@ public abstract class AbstractMessageWriterResultHandler extends HandlerResultHa ResolvableType elementType; if (adapter != null) { publisher = adapter.toPublisher(body); - ResolvableType genericType = bodyType.getGeneric(); + boolean isUnwrapped = KotlinDetector.isKotlinReflectPresent() && + KotlinDetector.isKotlinType(bodyParameter.getContainingClass()) && + KotlinDelegate.isSuspend(bodyParameter.getMethod()) && + !COROUTINES_FLOW_CLASS_NAME.equals(bodyType.toClass().getName()); + ResolvableType genericType = isUnwrapped ? bodyType : bodyType.getGeneric(); elementType = getElementType(adapter, genericType); } else { @@ -183,4 +193,15 @@ public abstract class AbstractMessageWriterResultHandler extends HandlerResultHa return writableMediaTypes; } + /** + * Inner class to avoid a hard dependency on Kotlin at runtime. + */ + private static class KotlinDelegate { + + static private boolean isSuspend(Method method) { + KFunction function = ReflectJvmMapping.getKotlinFunction(method); + return function != null && function.isSuspend(); + } + } + } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt new file mode 100644 index 00000000000..f2722552e3c --- /dev/null +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.result.method.annotation + +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.FlowPreview +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.async +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import org.junit.Assert.assertEquals +import org.junit.Test +import org.springframework.context.ApplicationContext +import org.springframework.context.annotation.AnnotationConfigApplicationContext +import org.springframework.context.annotation.ComponentScan +import org.springframework.context.annotation.Configuration +import org.springframework.http.HttpHeaders +import org.springframework.http.HttpStatus +import org.springframework.web.bind.annotation.GetMapping +import org.springframework.web.bind.annotation.RestController +import org.springframework.web.client.HttpServerErrorException +import org.springframework.web.reactive.config.EnableWebFlux + +@FlowPreview +class CoroutinesIntegrationTests : AbstractRequestMappingIntegrationTests() { + + override fun initApplicationContext(): ApplicationContext { + val context = AnnotationConfigApplicationContext() + context.register(WebConfig::class.java) + context.refresh() + return context + } + + @Test + fun `Suspending handler method`() { + val entity = performGet("/suspend", HttpHeaders.EMPTY, String::class.java) + assertEquals(HttpStatus.OK, entity.statusCode) + assertEquals("foo", entity.body) + } + + @Test + fun `Handler method returning Deferred`() { + val entity = performGet("/deferred", HttpHeaders.EMPTY, String::class.java) + assertEquals(HttpStatus.OK, entity.statusCode) + assertEquals("foo", entity.body) + } + + @Test + fun `Handler method returning Flow`() { + val entity = performGet("/flow", HttpHeaders.EMPTY, String::class.java) + assertEquals(HttpStatus.OK, entity.statusCode) + assertEquals("foobar", entity.body) + } + + @Test + fun `Suspending handler method returning Flow`() { + val entity = performGet("/suspending-flow", HttpHeaders.EMPTY, String::class.java) + assertEquals(HttpStatus.OK, entity.statusCode) + assertEquals("foobar", entity.body) + } + + @Test(expected = HttpServerErrorException.InternalServerError::class) + fun `Suspending handler method throwing exception`() { + performGet("/error", HttpHeaders.EMPTY, String::class.java) + } + + @Test(expected = HttpServerErrorException.InternalServerError::class) + fun `Handler method returning Flow throwing exception`() { + performGet("/flow-error", HttpHeaders.EMPTY, String::class.java) + } + + @Configuration + @EnableWebFlux + @ComponentScan(resourcePattern = "**/CoroutinesIntegrationTests*") + open class WebConfig + + @RestController + class CoroutinesController { + + @GetMapping("/suspend") + suspend fun suspendingEndpoint(): String { + delay(1) + return "foo" + } + + @GetMapping("/deferred") + fun deferredEndpoint(): Deferred = GlobalScope.async { + delay(1) + "foo" + } + + @GetMapping("/flow") + fun flowEndpoint()= flow { + emit("foo") + delay(1) + emit("bar") + delay(1) + } + + @GetMapping("/suspending-flow") + suspend fun suspendingFlowEndpoint(): Flow { + delay(10) + return flow { + emit("foo") + delay(1) + emit("bar") + delay(1) + } + } + + @GetMapping("/error") + suspend fun error() { + delay(1) + throw IllegalStateException() + } + + @GetMapping("/flow-error") + suspend fun flowError() = flow { + delay(1) + throw IllegalStateException() + } + + } +}