Browse Source

Introduce automatic context propagation in Coroutines

Closes gh-35485
pull/35641/head
Sébastien Deleuze 5 months ago
parent
commit
ec77bb0032
  1. 9
      spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java
  2. 30
      spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt

9
spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java

@ -45,6 +45,7 @@ import kotlinx.coroutines.reactor.ReactorFlowKt; @@ -45,6 +45,7 @@ import kotlinx.coroutines.reactor.ReactorFlowKt;
import org.jspecify.annotations.Nullable;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import org.springframework.util.Assert;
@ -85,7 +86,9 @@ public abstract class CoroutinesUtils { @@ -85,7 +86,9 @@ public abstract class CoroutinesUtils {
/**
* Invoke a suspending function and convert it to {@link Mono} or {@link Flux}.
* Uses an {@linkplain Dispatchers#getUnconfined() unconfined} dispatcher.
* Uses an {@linkplain Dispatchers#getUnconfined() unconfined} dispatcher, augmented with
* {@link PropagationContextElement} if
* {@linkplain Hooks#isAutomaticContextPropagationEnabled() Reactor automatic context propagation} is enabled.
* @param method the suspending function to invoke
* @param target the target to invoke {@code method} on
* @param args the function arguments. If the {@code Continuation} argument is specified as the last argument
@ -94,7 +97,9 @@ public abstract class CoroutinesUtils { @@ -94,7 +97,9 @@ public abstract class CoroutinesUtils {
* @throws IllegalArgumentException if {@code method} is not a suspending function
*/
public static Publisher<?> invokeSuspendingFunction(Method method, Object target, @Nullable Object... args) {
return invokeSuspendingFunction(Dispatchers.getUnconfined(), method, target, args);
CoroutineContext context = Hooks.isAutomaticContextPropagationEnabled() ?
Dispatchers.getUnconfined().plus(new PropagationContextElement()) : Dispatchers.getUnconfined();
return invokeSuspendingFunction(context, method, target, args);
}
/**

30
spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt

@ -16,20 +16,26 @@ @@ -16,20 +16,26 @@
package org.springframework.core
import io.micrometer.observation.Observation
import io.micrometer.observation.tck.TestObservationRegistry
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.reactor.awaitSingle
import kotlinx.coroutines.reactor.awaitSingleOrNull
import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.reactivestreams.Publisher
import reactor.core.publisher.Flux
import reactor.core.publisher.Hooks
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
import kotlin.coroutines.Continuation
import kotlin.coroutines.coroutineContext
import kotlin.reflect.full.primaryConstructor
import kotlin.reflect.jvm.isAccessible
import kotlin.reflect.jvm.javaMethod
/**
* Kotlin tests for [CoroutinesUtils].
@ -38,6 +44,8 @@ import kotlin.reflect.jvm.isAccessible @@ -38,6 +44,8 @@ import kotlin.reflect.jvm.isAccessible
*/
class CoroutinesUtilsTests {
private val observationRegistry = TestObservationRegistry.create()
@Test
fun deferredToMono() {
runBlocking {
@ -285,6 +293,21 @@ class CoroutinesUtilsTests { @@ -285,6 +293,21 @@ class CoroutinesUtilsTests {
}
}
@Test
@Suppress("UNCHECKED_CAST")
fun invokeSuspendingFunctionWithObservation() {
Hooks.enableAutomaticContextPropagation()
val method = CoroutinesUtilsTests::suspendingObservationFunction::javaMethod.get()!!
val publisher = CoroutinesUtils.invokeSuspendingFunction(method, this, "test", null) as Publisher<String>
val observation = Observation.createNotStarted("coroutine", observationRegistry)
observation.observe {
val mono = Mono.from<String>(publisher)
val result = mono.block()
assertThat(result).isEqualTo("coroutine")
}
Hooks.disableAutomaticContextPropagation()
}
suspend fun suspendingFunction(value: String): String {
delay(1)
return value
@ -417,4 +440,11 @@ class CoroutinesUtilsTests { @@ -417,4 +440,11 @@ class CoroutinesUtilsTests {
class CustomException(message: String) : Throwable(message)
suspend fun suspendingObservationFunction(value: String): String? {
delay(1)
val currentObservation = observationRegistry.currentObservation
assertThat(currentObservation).isNotNull
return currentObservation?.context?.name
}
}

Loading…
Cancel
Save