@ -21,12 +21,15 @@ import kotlinx.coroutines.flow.Flow
@@ -21,12 +21,15 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertTha t
import org.assertj.core.api.Assertions.fail
import kotlinx.coroutines.withContex t
import org.assertj.core.api.Assertions.*
import org.junit.jupiter.api.Test
import org.springframework.aop.framework.ProxyFactory
import org.springframework.transaction.interceptor.TransactionInterceptor
import org.springframework.transaction.testfixture.ReactiveCallCountingTransactionManager
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext
/ * *
* @author Sebastien Deleuze
@ -118,6 +121,36 @@ class CoroutinesAnnotationTransactionInterceptorTests {
@@ -118,6 +121,36 @@ class CoroutinesAnnotationTransactionInterceptorTests {
assertReactiveGetTransactionAndCommitCount ( 1 )
}
@Test
fun suspendingValueSuccessWithContext ( ) {
val proxyFactory = ProxyFactory ( )
proxyFactory . setTarget ( TestWithCoroutines ( ) )
proxyFactory . addAdvice ( TransactionInterceptor ( rtm , source ) )
val proxy = proxyFactory . proxy as TestWithCoroutines
assertThat ( runBlocking {
withExampleContext ( " context " ) {
proxy . suspendingValueSuccessWithContext ( )
}
} ) . isEqualTo ( " context " )
assertReactiveGetTransactionAndCommitCount ( 1 )
}
@Test
fun suspendingValueFailureWithContext ( ) {
val proxyFactory = ProxyFactory ( )
proxyFactory . setTarget ( TestWithCoroutines ( ) )
proxyFactory . addAdvice ( TransactionInterceptor ( rtm , source ) )
val proxy = proxyFactory . proxy as TestWithCoroutines
assertThatIllegalStateException ( ) . isThrownBy {
runBlocking {
withExampleContext ( " context " ) {
proxy . suspendingValueFailureWithContext ( )
}
}
} . withMessage ( " context " )
assertReactiveGetTransactionAndRollbackCount ( 1 )
}
private fun assertReactiveGetTransactionAndCommitCount ( expectedCount : Int ) {
assertThat ( rtm . begun ) . isEqualTo ( expectedCount )
assertThat ( rtm . commits ) . isEqualTo ( expectedCount )
@ -166,5 +199,27 @@ class CoroutinesAnnotationTransactionInterceptorTests {
@@ -166,5 +199,27 @@ class CoroutinesAnnotationTransactionInterceptorTests {
emit ( " foo " )
}
}
open suspend fun suspendingValueSuccessWithContext ( ) : String {
delay ( 10 )
return coroutineContext [ ExampleContext . Key ] . toString ( )
}
open suspend fun suspendingValueFailureWithContext ( ) : String {
delay ( 10 )
throw IllegalStateException ( coroutineContext [ ExampleContext . Key ] . toString ( ) )
}
}
}
data class ExampleContext ( val value : String ) : AbstractCoroutineContextElement ( ExampleContext ) {
companion object Key : CoroutineContext . Key < ExampleContext >
override fun toString ( ) : String = value
}
private suspend fun withExampleContext ( inputValue : String , f : suspend ( ) -> String ) =
withContext ( ExampleContext ( inputValue ) ) {
f ( )
}