From 441e2105338835c1d43dc176d6527c791febc8d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Wed, 22 Nov 2023 12:40:32 +0100 Subject: [PATCH] Treat kotlin.Unit as void in web controllers This commit fixes a regression introduced by gh-21139 via the usage of Kotlin reflection to invoke HTTP handler methods. It ensures that kotlin.Unit is treated as void by returning null. It also polishes CoroutinesUtils to have a consistent handling compared to the regular case, and adds related tests to prevent future regressions. Closes gh-31648 --- .../springframework/core/CoroutinesUtils.java | 2 +- .../core/CoroutinesUtilsTests.kt | 26 +++++++++++++++++++ .../support/InvocableHandlerMethod.java | 4 ++- .../InvocableHandlerMethodKotlinTests.kt | 21 +++++++++++++++ .../result/method/InvocableHandlerMethod.java | 4 ++- .../InvocableHandlerMethodKotlinTests.kt | 26 ++++++++++++++++++- 6 files changed, 79 insertions(+), 4 deletions(-) diff --git a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java index c2babd7236b..e25a8be6d2f 100644 --- a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java +++ b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java @@ -123,7 +123,7 @@ public abstract class CoroutinesUtils { } return KCallables.callSuspendBy(function, argMap, continuation); }) - .filter(result -> !Objects.equals(result, Unit.INSTANCE)) + .filter(result -> result != Unit.INSTANCE) .onErrorMap(InvocationTargetException.class, InvocationTargetException::getTargetException); KClassifier returnType = function.getReturnType().getClassifier(); diff --git a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt index ad7e4ac5d85..9a7029c7f13 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt @@ -20,6 +20,7 @@ import kotlinx.coroutines.* import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.reactor.awaitSingle +import kotlinx.coroutines.reactor.awaitSingleOrNull import org.assertj.core.api.Assertions import org.junit.jupiter.api.Test import reactor.core.publisher.Flux @@ -126,6 +127,24 @@ class CoroutinesUtilsTests { Assertions.assertThatIllegalArgumentException().isThrownBy { CoroutinesUtils.invokeSuspendingFunction(context, method, this, "foo") } } + @Test + fun invokeSuspendingFunctionReturningUnit() { + val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingUnit", Continuation::class.java) + val mono = CoroutinesUtils.invokeSuspendingFunction(method, this) as Mono + runBlocking { + Assertions.assertThat(mono.awaitSingleOrNull()).isNull() + } + } + + @Test + fun invokeSuspendingFunctionReturningNull() { + val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingNullable", Continuation::class.java) + val mono = CoroutinesUtils.invokeSuspendingFunction(method, this) as Mono + runBlocking { + Assertions.assertThat(mono.awaitSingleOrNull()).isNull() + } + } + suspend fun suspendingFunction(value: String): String { delay(1) return value @@ -146,4 +165,11 @@ class CoroutinesUtilsTests { return value } + suspend fun suspendingUnit() { + } + + suspend fun suspendingNullable(): String? { + return null + } + } diff --git a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java index 87c8b3d3b4a..39a7c918242 100644 --- a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java +++ b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Map; import java.util.Objects; +import kotlin.Unit; import kotlin.reflect.KFunction; import kotlin.reflect.KParameter; import kotlin.reflect.jvm.KCallablesJvm; @@ -315,7 +316,8 @@ public class InvocableHandlerMethod extends HandlerMethod { } } } - return function.callBy(argMap); + Object result = function.callBy(argMap); + return (result == Unit.INSTANCE ? null : result); } } diff --git a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt index ad56c846e5c..24549b56c6f 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt @@ -71,6 +71,19 @@ class InvocableHandlerMethodKotlinTests { Assertions.assertThat(value).isEqualTo("true") } + @Test + fun unitReturnValue() { + val value = getInvocable().invokeForRequest(request, null) + Assertions.assertThat(value).isNull() + } + + @Test + fun nullReturnValue() { + composite.addResolver(StubArgumentResolver(String::class.java, null)) + val value = getInvocable(String::class.java).invokeForRequest(request, null) + Assertions.assertThat(value).isNull() + } + private fun getInvocable(vararg argTypes: Class<*>): InvocableHandlerMethod { val method = ResolvableMethod.on(Handler::class.java).argTypes(*argTypes).resolveMethod() val handlerMethod = InvocableHandlerMethod(Handler(), method) @@ -95,6 +108,14 @@ class InvocableHandlerMethodKotlinTests { fun nullableBooleanDefaultValue(status: Boolean? = true) = status.toString() + + fun unit(): Unit { + } + + @Suppress("UNUSED_PARAMETER") + fun nullable(arg: String?): String? { + return null + } } } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index 63dbb80a551..466adbcf1b5 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Objects; import java.util.stream.Stream; +import kotlin.Unit; import kotlin.coroutines.CoroutineContext; import kotlin.reflect.KFunction; import kotlin.reflect.KParameter; @@ -326,7 +327,8 @@ public class InvocableHandlerMethod extends HandlerMethod { } } } - return function.callBy(argMap); + Object result = function.callBy(argMap); + return (result == Unit.INSTANCE ? null : result); } } } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt index e7a1fe5a9b9..7226dea583b 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt @@ -164,6 +164,20 @@ class InvocableHandlerMethodKotlinTests { assertHandlerResultValue(result, "override") } + @Test + fun unitReturnValue() { + val method = NullResultController::unit.javaMethod!! + val result = invoke(NullResultController(), method) + assertHandlerResultValue(result, null) + } + + @Test + fun nullReturnValue() { + val method = NullResultController::nullable.javaMethod!! + val result = invoke(NullResultController(), method) + assertHandlerResultValue(result, null) + } + private fun invokeForResult(handler: Any, method: Method, vararg providedArgs: Any): HandlerResult? { return invoke(handler, method, *providedArgs).block(Duration.ofSeconds(5)) @@ -186,7 +200,7 @@ class InvocableHandlerMethodKotlinTests { return resolver } - private fun assertHandlerResultValue(mono: Mono, expected: String) { + private fun assertHandlerResultValue(mono: Mono, expected: String?) { StepVerifier.create(mono) .consumeNextWith { if (it.returnValue is Mono<*>) { @@ -242,4 +256,14 @@ class InvocableHandlerMethodKotlinTests { @Suppress("RedundantSuspendModifier") suspend fun handleSuspending(@RequestParam value: String = "default") = value } + + class NullResultController { + + fun unit() { + } + + fun nullable(): String? { + return null + } + } } \ No newline at end of file