From aee2df8919befd7162723b4bd5cbd480cdd6993c Mon Sep 17 00:00:00 2001 From: Sebastien Deleuze Date: Tue, 23 Apr 2019 11:09:41 +0200 Subject: [PATCH] Improve WebFlux suspending handler method support Support for suspending handler methods introduced in Spring Framework 5.2 M1 does not detect types correctly and does not support suspending handler methods returning Flow which is a common use case with WebClient. This commit fixes these issues and adds Coroutines integration tests. Closes gh-22820 Closes gh-22827 --- .../springframework/core/CoroutinesUtils.kt | 27 +++- .../AbstractMessageWriterResultHandler.java | 25 +++- .../annotation/CoroutinesIntegrationTests.kt | 139 ++++++++++++++++++ 3 files changed, 184 insertions(+), 7 deletions(-) create mode 100644 spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt diff --git a/spring-core-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt b/spring-core-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt index 50d535abefb..b4e4cd8e0ab 100644 --- a/spring-core-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt +++ b/spring-core-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt @@ -19,8 +19,12 @@ package org.springframework.core import kotlinx.coroutines.Deferred import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.async +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactor.mono @@ -29,6 +33,8 @@ import reactor.core.publisher.onErrorMap import java.lang.reflect.InvocationTargetException import java.lang.reflect.Method import kotlin.reflect.full.callSuspend +import kotlin.reflect.full.isSubtypeOf +import kotlin.reflect.full.starProjectedType import kotlin.reflect.jvm.kotlinFunction /** @@ -50,18 +56,29 @@ internal fun monoToDeferred(source: Mono) = GlobalScope.async(Dispatchers.Unconfined) { source.awaitFirstOrNull() } /** - * Invoke an handler method converting suspending method to [Mono] if necessary. + * Invoke an handler method converting suspending method to [Mono] or [Flow] if necessary. * * @author Sebastien Deleuze * @since 5.2 */ +@Suppress("UNCHECKED_CAST") +@FlowPreview internal fun invokeHandlerMethod(method: Method, bean: Any, vararg args: Any?): Any? { val function = method.kotlinFunction!! return if (function.isSuspend) { - GlobalScope.mono(Dispatchers.Unconfined) { - function.callSuspend(bean, *args.sliceArray(0..(args.size-2))) - .let { if (it == Unit) null else it} } - .onErrorMap(InvocationTargetException::class) { it.targetException } + if (function.returnType.isSubtypeOf(Flow::class.starProjectedType)) { + flow { + (function.callSuspend(bean, *args.sliceArray(0..(args.size-2))) as Flow<*>).collect { + emit(it) + } + } + } + else { + GlobalScope.mono(Dispatchers.Unconfined) { + function.callSuspend(bean, *args.sliceArray(0..(args.size-2))) + .let { if (it == Unit) null else it} + }.onErrorMap(InvocationTargetException::class) { it.targetException } + } } else { function.call(bean, *args) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java index 23eb9422a17..5b843491c9b 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java @@ -16,12 +16,16 @@ package org.springframework.web.reactive.result.method.annotation; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; +import kotlin.reflect.KFunction; +import kotlin.reflect.jvm.ReflectJvmMapping; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; +import org.springframework.core.KotlinDetector; import org.springframework.core.MethodParameter; import org.springframework.core.ReactiveAdapter; import org.springframework.core.ReactiveAdapterRegistry; @@ -48,6 +52,8 @@ import org.springframework.web.server.ServerWebExchange; */ public abstract class AbstractMessageWriterResultHandler extends HandlerResultHandlerSupport { + private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow"; + private final List> messageWriters; @@ -110,7 +116,7 @@ public abstract class AbstractMessageWriterResultHandler extends HandlerResultHa * @return indicates completion or error * @since 5.0.2 */ - @SuppressWarnings({"unchecked", "rawtypes"}) + @SuppressWarnings({"unchecked", "rawtypes", "ConstantConditions"}) protected Mono writeBody(@Nullable Object body, MethodParameter bodyParameter, @Nullable MethodParameter actualParam, ServerWebExchange exchange) { @@ -122,7 +128,11 @@ public abstract class AbstractMessageWriterResultHandler extends HandlerResultHa ResolvableType elementType; if (adapter != null) { publisher = adapter.toPublisher(body); - ResolvableType genericType = bodyType.getGeneric(); + boolean isUnwrapped = KotlinDetector.isKotlinReflectPresent() && + KotlinDetector.isKotlinType(bodyParameter.getContainingClass()) && + KotlinDelegate.isSuspend(bodyParameter.getMethod()) && + !COROUTINES_FLOW_CLASS_NAME.equals(bodyType.toClass().getName()); + ResolvableType genericType = isUnwrapped ? bodyType : bodyType.getGeneric(); elementType = getElementType(adapter, genericType); } else { @@ -183,4 +193,15 @@ public abstract class AbstractMessageWriterResultHandler extends HandlerResultHa return writableMediaTypes; } + /** + * Inner class to avoid a hard dependency on Kotlin at runtime. + */ + private static class KotlinDelegate { + + static private boolean isSuspend(Method method) { + KFunction function = ReflectJvmMapping.getKotlinFunction(method); + return function != null && function.isSuspend(); + } + } + } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt new file mode 100644 index 00000000000..f2722552e3c --- /dev/null +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.result.method.annotation + +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.FlowPreview +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.async +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import org.junit.Assert.assertEquals +import org.junit.Test +import org.springframework.context.ApplicationContext +import org.springframework.context.annotation.AnnotationConfigApplicationContext +import org.springframework.context.annotation.ComponentScan +import org.springframework.context.annotation.Configuration +import org.springframework.http.HttpHeaders +import org.springframework.http.HttpStatus +import org.springframework.web.bind.annotation.GetMapping +import org.springframework.web.bind.annotation.RestController +import org.springframework.web.client.HttpServerErrorException +import org.springframework.web.reactive.config.EnableWebFlux + +@FlowPreview +class CoroutinesIntegrationTests : AbstractRequestMappingIntegrationTests() { + + override fun initApplicationContext(): ApplicationContext { + val context = AnnotationConfigApplicationContext() + context.register(WebConfig::class.java) + context.refresh() + return context + } + + @Test + fun `Suspending handler method`() { + val entity = performGet("/suspend", HttpHeaders.EMPTY, String::class.java) + assertEquals(HttpStatus.OK, entity.statusCode) + assertEquals("foo", entity.body) + } + + @Test + fun `Handler method returning Deferred`() { + val entity = performGet("/deferred", HttpHeaders.EMPTY, String::class.java) + assertEquals(HttpStatus.OK, entity.statusCode) + assertEquals("foo", entity.body) + } + + @Test + fun `Handler method returning Flow`() { + val entity = performGet("/flow", HttpHeaders.EMPTY, String::class.java) + assertEquals(HttpStatus.OK, entity.statusCode) + assertEquals("foobar", entity.body) + } + + @Test + fun `Suspending handler method returning Flow`() { + val entity = performGet("/suspending-flow", HttpHeaders.EMPTY, String::class.java) + assertEquals(HttpStatus.OK, entity.statusCode) + assertEquals("foobar", entity.body) + } + + @Test(expected = HttpServerErrorException.InternalServerError::class) + fun `Suspending handler method throwing exception`() { + performGet("/error", HttpHeaders.EMPTY, String::class.java) + } + + @Test(expected = HttpServerErrorException.InternalServerError::class) + fun `Handler method returning Flow throwing exception`() { + performGet("/flow-error", HttpHeaders.EMPTY, String::class.java) + } + + @Configuration + @EnableWebFlux + @ComponentScan(resourcePattern = "**/CoroutinesIntegrationTests*") + open class WebConfig + + @RestController + class CoroutinesController { + + @GetMapping("/suspend") + suspend fun suspendingEndpoint(): String { + delay(1) + return "foo" + } + + @GetMapping("/deferred") + fun deferredEndpoint(): Deferred = GlobalScope.async { + delay(1) + "foo" + } + + @GetMapping("/flow") + fun flowEndpoint()= flow { + emit("foo") + delay(1) + emit("bar") + delay(1) + } + + @GetMapping("/suspending-flow") + suspend fun suspendingFlowEndpoint(): Flow { + delay(10) + return flow { + emit("foo") + delay(1) + emit("bar") + delay(1) + } + } + + @GetMapping("/error") + suspend fun error() { + delay(1) + throw IllegalStateException() + } + + @GetMapping("/flow-error") + suspend fun flowError() = flow { + delay(1) + throw IllegalStateException() + } + + } +}