Browse Source

Introduce automatic context propagation in Coroutines

Closes gh-35485
pull/35641/head
Sébastien Deleuze 5 months ago
parent
commit
ec77bb0032
  1. 9
      spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java
  2. 30
      spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt

9
spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java

@ -45,6 +45,7 @@ import kotlinx.coroutines.reactor.ReactorFlowKt;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -85,7 +86,9 @@ public abstract class CoroutinesUtils {
/** /**
* Invoke a suspending function and convert it to {@link Mono} or {@link Flux}. * Invoke a suspending function and convert it to {@link Mono} or {@link Flux}.
* Uses an {@linkplain Dispatchers#getUnconfined() unconfined} dispatcher. * Uses an {@linkplain Dispatchers#getUnconfined() unconfined} dispatcher, augmented with
* {@link PropagationContextElement} if
* {@linkplain Hooks#isAutomaticContextPropagationEnabled() Reactor automatic context propagation} is enabled.
* @param method the suspending function to invoke * @param method the suspending function to invoke
* @param target the target to invoke {@code method} on * @param target the target to invoke {@code method} on
* @param args the function arguments. If the {@code Continuation} argument is specified as the last argument * @param args the function arguments. If the {@code Continuation} argument is specified as the last argument
@ -94,7 +97,9 @@ public abstract class CoroutinesUtils {
* @throws IllegalArgumentException if {@code method} is not a suspending function * @throws IllegalArgumentException if {@code method} is not a suspending function
*/ */
public static Publisher<?> invokeSuspendingFunction(Method method, Object target, @Nullable Object... args) { public static Publisher<?> invokeSuspendingFunction(Method method, Object target, @Nullable Object... args) {
return invokeSuspendingFunction(Dispatchers.getUnconfined(), method, target, args); CoroutineContext context = Hooks.isAutomaticContextPropagationEnabled() ?
Dispatchers.getUnconfined().plus(new PropagationContextElement()) : Dispatchers.getUnconfined();
return invokeSuspendingFunction(context, method, target, args);
} }
/** /**

30
spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt

@ -16,20 +16,26 @@
package org.springframework.core package org.springframework.core
import io.micrometer.observation.Observation
import io.micrometer.observation.tck.TestObservationRegistry
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.reactor.awaitSingle import kotlinx.coroutines.reactor.awaitSingle
import kotlinx.coroutines.reactor.awaitSingleOrNull import kotlinx.coroutines.reactor.awaitSingleOrNull
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.reactivestreams.Publisher
import reactor.core.publisher.Flux import reactor.core.publisher.Flux
import reactor.core.publisher.Hooks
import reactor.core.publisher.Mono import reactor.core.publisher.Mono
import reactor.test.StepVerifier import reactor.test.StepVerifier
import kotlin.coroutines.Continuation import kotlin.coroutines.Continuation
import kotlin.coroutines.coroutineContext import kotlin.coroutines.coroutineContext
import kotlin.reflect.full.primaryConstructor import kotlin.reflect.full.primaryConstructor
import kotlin.reflect.jvm.isAccessible import kotlin.reflect.jvm.isAccessible
import kotlin.reflect.jvm.javaMethod
/** /**
* Kotlin tests for [CoroutinesUtils]. * Kotlin tests for [CoroutinesUtils].
@ -38,6 +44,8 @@ import kotlin.reflect.jvm.isAccessible
*/ */
class CoroutinesUtilsTests { class CoroutinesUtilsTests {
private val observationRegistry = TestObservationRegistry.create()
@Test @Test
fun deferredToMono() { fun deferredToMono() {
runBlocking { runBlocking {
@ -285,6 +293,21 @@ class CoroutinesUtilsTests {
} }
} }
@Test
@Suppress("UNCHECKED_CAST")
fun invokeSuspendingFunctionWithObservation() {
Hooks.enableAutomaticContextPropagation()
val method = CoroutinesUtilsTests::suspendingObservationFunction::javaMethod.get()!!
val publisher = CoroutinesUtils.invokeSuspendingFunction(method, this, "test", null) as Publisher<String>
val observation = Observation.createNotStarted("coroutine", observationRegistry)
observation.observe {
val mono = Mono.from<String>(publisher)
val result = mono.block()
assertThat(result).isEqualTo("coroutine")
}
Hooks.disableAutomaticContextPropagation()
}
suspend fun suspendingFunction(value: String): String { suspend fun suspendingFunction(value: String): String {
delay(1) delay(1)
return value return value
@ -417,4 +440,11 @@ class CoroutinesUtilsTests {
class CustomException(message: String) : Throwable(message) class CustomException(message: String) : Throwable(message)
suspend fun suspendingObservationFunction(value: String): String? {
delay(1)
val currentObservation = observationRegistry.currentObservation
assertThat(currentObservation).isNotNull
return currentObservation?.context?.name
}
} }

Loading…
Cancel
Save