From 45ae00fda305a9ff5a4817cf6531aaf29c11e1bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Thu, 2 Feb 2023 14:06:29 +0100 Subject: [PATCH] Propagate the context in Coroutines transactions This commit ensures that CoroutineContext is properly propagated in transactional suspending functions. Both annotation and functional variants are supported. Closes gh-27308 --- .../interceptor/TransactionAspectSupport.java | 12 +++- .../TransactionalOperatorExtensions.kt | 17 ++++-- ...esAnnotationTransactionInterceptorTests.kt | 59 ++++++++++++++++++- .../TransactionalOperatorExtensionsTests.kt | 51 +++++++++++++++- 4 files changed, 128 insertions(+), 11 deletions(-) diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java index 1dbbdaa6055..a8a4de98f84 100644 --- a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,8 @@ import java.util.concurrent.ConcurrentMap; import io.vavr.control.Try; import kotlin.coroutines.Continuation; +import kotlin.coroutines.CoroutineContext; +import kotlinx.coroutines.Job; import kotlinx.coroutines.reactive.AwaitKt; import kotlinx.coroutines.reactive.ReactiveFlowKt; import org.apache.commons.logging.Log; @@ -363,7 +365,7 @@ public abstract class TransactionAspectSupport implements BeanFactoryAware, Init InvocationCallback callback = invocation; if (corInv != null) { - callback = () -> CoroutinesUtils.invokeSuspendingFunction(method, corInv.getTarget(), corInv.getArguments()); + callback = () -> KotlinDelegate.invokeSuspendingFunction(method, corInv); } Object result = txSupport.invokeWithinTransaction(method, targetClass, callback, txAttr, (ReactiveTransactionManager) tm); if (corInv != null) { @@ -883,6 +885,12 @@ public abstract class TransactionAspectSupport implements BeanFactoryAware, Init private static Object awaitSingleOrNull(Publisher publisher, Object continuation) { return AwaitKt.awaitSingleOrNull(publisher, (Continuation) continuation); } + + public static Publisher invokeSuspendingFunction(Method method, CoroutinesInvocationCallback callback) { + CoroutineContext coroutineContext = ((Continuation) callback.getContinuation()).getContext().minusKey(Job.Key); + return CoroutinesUtils.invokeSuspendingFunction(coroutineContext, method, callback.getTarget(), callback.getArguments()); + } + } diff --git a/spring-tx/src/main/kotlin/org/springframework/transaction/reactive/TransactionalOperatorExtensions.kt b/spring-tx/src/main/kotlin/org/springframework/transaction/reactive/TransactionalOperatorExtensions.kt index ef0f09b1fad..09d7e7efd2b 100644 --- a/spring-tx/src/main/kotlin/org/springframework/transaction/reactive/TransactionalOperatorExtensions.kt +++ b/spring-tx/src/main/kotlin/org/springframework/transaction/reactive/TransactionalOperatorExtensions.kt @@ -16,14 +16,17 @@ package org.springframework.transaction.reactive -import java.util.Optional -import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.reactive.asFlow import kotlinx.coroutines.reactive.awaitLast import kotlinx.coroutines.reactor.asFlux import kotlinx.coroutines.reactor.mono import org.springframework.transaction.ReactiveTransaction +import java.util.* +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext /** * Coroutines variant of [TransactionalOperator.transactional] as a [Flow] extension. @@ -31,8 +34,8 @@ import org.springframework.transaction.ReactiveTransaction * @author Sebastien Deleuze * @since 5.2 */ -fun Flow.transactional(operator: TransactionalOperator): Flow = - operator.transactional(asFlux()).asFlow() +fun Flow.transactional(operator: TransactionalOperator, context: CoroutineContext = EmptyCoroutineContext): Flow = + operator.transactional(asFlux(context)).asFlow() /** * Coroutines variant of [TransactionalOperator.execute] with a suspending lambda @@ -42,6 +45,8 @@ fun Flow.transactional(operator: TransactionalOperator): Flow = * @author Mark Paluch * @since 5.2 */ -suspend fun TransactionalOperator.executeAndAwait(f: suspend (ReactiveTransaction) -> T): T = - execute { status -> mono(Dispatchers.Unconfined) { f(status) } }.map { value -> Optional.ofNullable(value) } +suspend fun TransactionalOperator.executeAndAwait(f: suspend (ReactiveTransaction) -> T): T { + val context = currentCoroutineContext().minusKey(Job.Key) + return execute { status -> mono(context) { f(status) } }.map { value -> Optional.ofNullable(value) } .defaultIfEmpty(Optional.empty()).awaitLast().orElse(null) +} diff --git a/spring-tx/src/test/kotlin/org/springframework/transaction/annotation/CoroutinesAnnotationTransactionInterceptorTests.kt b/spring-tx/src/test/kotlin/org/springframework/transaction/annotation/CoroutinesAnnotationTransactionInterceptorTests.kt index 8af92330a3e..6130b4941f8 100644 --- a/spring-tx/src/test/kotlin/org/springframework/transaction/annotation/CoroutinesAnnotationTransactionInterceptorTests.kt +++ b/spring-tx/src/test/kotlin/org/springframework/transaction/annotation/CoroutinesAnnotationTransactionInterceptorTests.kt @@ -21,12 +21,15 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking -import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.fail +import kotlinx.coroutines.withContext +import org.assertj.core.api.Assertions.* import org.junit.jupiter.api.Test import org.springframework.aop.framework.ProxyFactory import org.springframework.transaction.interceptor.TransactionInterceptor import org.springframework.transaction.testfixture.ReactiveCallCountingTransactionManager +import kotlin.coroutines.AbstractCoroutineContextElement +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.coroutineContext /** * @author Sebastien Deleuze @@ -118,6 +121,36 @@ class CoroutinesAnnotationTransactionInterceptorTests { assertReactiveGetTransactionAndCommitCount(1) } + @Test + fun suspendingValueSuccessWithContext() { + val proxyFactory = ProxyFactory() + proxyFactory.setTarget(TestWithCoroutines()) + proxyFactory.addAdvice(TransactionInterceptor(rtm, source)) + val proxy = proxyFactory.proxy as TestWithCoroutines + assertThat(runBlocking { + withExampleContext("context") { + proxy.suspendingValueSuccessWithContext() + } + }).isEqualTo("context") + assertReactiveGetTransactionAndCommitCount(1) + } + + @Test + fun suspendingValueFailureWithContext() { + val proxyFactory = ProxyFactory() + proxyFactory.setTarget(TestWithCoroutines()) + proxyFactory.addAdvice(TransactionInterceptor(rtm, source)) + val proxy = proxyFactory.proxy as TestWithCoroutines + assertThatIllegalStateException().isThrownBy { + runBlocking { + withExampleContext("context") { + proxy.suspendingValueFailureWithContext() + } + } + }.withMessage("context") + assertReactiveGetTransactionAndRollbackCount(1) + } + private fun assertReactiveGetTransactionAndCommitCount(expectedCount: Int) { assertThat(rtm.begun).isEqualTo(expectedCount) assertThat(rtm.commits).isEqualTo(expectedCount) @@ -166,5 +199,27 @@ class CoroutinesAnnotationTransactionInterceptorTests { emit("foo") } } + + open suspend fun suspendingValueSuccessWithContext(): String { + delay(10) + return coroutineContext[ExampleContext.Key].toString() + } + + open suspend fun suspendingValueFailureWithContext(): String { + delay(10) + throw IllegalStateException(coroutineContext[ExampleContext.Key].toString()) + } } } + +data class ExampleContext(val value: String) : AbstractCoroutineContextElement(ExampleContext) { + + companion object Key : CoroutineContext.Key + + override fun toString(): String = value +} + +private suspend fun withExampleContext(inputValue: String, f: suspend () -> String) = + withContext(ExampleContext(inputValue)) { + f() + } diff --git a/spring-tx/src/test/kotlin/org/springframework/transaction/reactive/TransactionalOperatorExtensionsTests.kt b/spring-tx/src/test/kotlin/org/springframework/transaction/reactive/TransactionalOperatorExtensionsTests.kt index 82254142154..6f95b2b5bb3 100644 --- a/spring-tx/src/test/kotlin/org/springframework/transaction/reactive/TransactionalOperatorExtensionsTests.kt +++ b/spring-tx/src/test/kotlin/org/springframework/transaction/reactive/TransactionalOperatorExtensionsTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.transaction.reactive +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.delay import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.toList @@ -23,6 +24,8 @@ import kotlinx.coroutines.runBlocking import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.springframework.transaction.support.DefaultTransactionDefinition +import kotlin.coroutines.AbstractCoroutineContextElement +import kotlin.coroutines.CoroutineContext class TransactionalOperatorExtensionsTests { @@ -107,4 +110,50 @@ class TransactionalOperatorExtensionsTests { } } } + + @Test + fun coroutineContextWithSuspendingFunction() { + val operator = TransactionalOperator.create(tm, DefaultTransactionDefinition()) + runBlocking(User(role = "admin")) { + try { + operator.executeAndAwait { + delay(1) + val currentUser = currentCoroutineContext()[User] + assertThat(currentUser).isNotNull() + assertThat(currentUser!!.role).isEqualTo("admin") + throw IllegalStateException() + } + } catch (e: IllegalStateException) { + assertThat(tm.commit).isFalse() + assertThat(tm.rollback).isTrue() + return@runBlocking + } + } + } + + @Test + fun coroutineContextWithFlow() { + val operator = TransactionalOperator.create(tm, DefaultTransactionDefinition()) + val flow = flow { + delay(1) + val currentUser = currentCoroutineContext()[User] + assertThat(currentUser).isNotNull() + assertThat(currentUser!!.role).isEqualTo("admin") + throw IllegalStateException() + } + runBlocking(User(role = "admin")) { + try { + flow.transactional(operator, coroutineContext).toList() + } catch (e: IllegalStateException) { + assertThat(tm.commit).isFalse() + assertThat(tm.rollback).isTrue() + return@runBlocking + } + } + } + + + private data class User(val role: String) : AbstractCoroutineContextElement(User) { + companion object Key : CoroutineContext.Key + } }