Browse Source

Adapt for Spring Framework Coroutines AOP support.

This commit adapts Spring Data RepositoryMethodInvoker
and related tests in order to remove most of the
Coroutines specific code and rely on Spring Framework
Coroutines AOP support.

Closes #2926
pull/2927/head
Sébastien Deleuze 2 years ago committed by Mark Paluch
parent
commit
110756a40a
No known key found for this signature in database
GPG Key ID: 4406B84C1661DCD1
  1. 86
      src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java
  2. 47
      src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java
  3. 6
      src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt
  4. 12
      src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt

86
src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java

@ -15,18 +15,16 @@ @@ -15,18 +15,16 @@
*/
package org.springframework.data.repository.core.support;
import kotlin.coroutines.Continuation;
import kotlin.reflect.KFunction;
import kotlinx.coroutines.reactive.AwaitKt;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.stream.Stream;
import org.reactivestreams.Publisher;
import org.springframework.aop.support.AopUtils;
import org.springframework.core.KotlinDetector;
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocation;
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocationResult;
@ -116,12 +114,7 @@ abstract class RepositoryMethodInvoker { @@ -116,12 +114,7 @@ abstract class RepositoryMethodInvoker {
@Nullable
public Object invoke(Class<?> repositoryInterface, RepositoryInvocationMulticaster multicaster, Object[] args)
throws Exception {
return shouldAdaptReactiveToSuspended() ? doInvokeReactiveToSuspended(repositoryInterface, multicaster, args)
: doInvoke(repositoryInterface, multicaster, args);
}
protected boolean shouldAdaptReactiveToSuspended() {
return suspendedDeclaredMethod;
return doInvoke(repositoryInterface, multicaster, args);
}
@Nullable
@ -153,46 +146,6 @@ abstract class RepositoryMethodInvoker { @@ -153,46 +146,6 @@ abstract class RepositoryMethodInvoker {
}
}
@Nullable
@SuppressWarnings({ "unchecked", "ConstantConditions" })
private Object doInvokeReactiveToSuspended(Class<?> repositoryInterface, RepositoryInvocationMulticaster multicaster,
Object[] args) throws Exception {
/*
* Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context.
* We're invoking a method without Continuation as we expect the method to return any sort of reactive type,
* therefore we need to strip the Continuation parameter.
*/
Continuation<Object> continuation = (Continuation) args[args.length - 1];
args[args.length - 1] = null;
RepositoryMethodInvocationCaptor invocationResultCaptor = RepositoryMethodInvocationCaptor
.captureInvocationOn(repositoryInterface);
try {
Publisher<?> result = new ReactiveInvocationListenerDecorator().decorate(repositoryInterface, multicaster, args,
invokable.invoke(args));
if (returnsReactiveType) {
return ReactiveWrapperConverters.toWrapper(result, returnedType);
}
if (Collection.class.isAssignableFrom(returnedType)) {
result = (Publisher<?>) collectToList(result);
}
return AwaitKt.awaitFirstOrNull(result, continuation);
} catch (Exception e) {
multicaster.notifyListeners(method, args, computeInvocationResult(invocationResultCaptor.error(e)));
throw e;
}
}
// to avoid NoClassDefFoundError: org/reactivestreams/Publisher when loading this class ¯\_(ツ)_/¯
private static Object collectToList(Object result) {
return Flux.from((Publisher<?>) result).collectList();
}
private RepositoryMethodInvocation computeInvocationResult(RepositoryMethodInvocationCaptor captured) {
return new RepositoryMethodInvocation(captured.getRepositoryInterface(), method, captured.getCapturedResult(),
captured.getDuration());
@ -271,30 +224,27 @@ abstract class RepositoryMethodInvoker { @@ -271,30 +224,27 @@ abstract class RepositoryMethodInvoker {
public RepositoryFragmentMethodInvoker(CoroutineAdapterInformation adapterInformation, Method declaredMethod,
Object instance, Method baseClassMethod) {
super(declaredMethod, args -> {
if (adapterInformation.isAdapterMethod()) {
/*
* Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context.
* We're invoking a method without Continuation as we expect the method to return any sort of reactive type,
* therefore we need to strip the Continuation parameter.
*/
Object[] invocationArguments = new Object[args.length - 1];
System.arraycopy(args, 0, invocationArguments, 0, invocationArguments.length);
return baseClassMethod.invoke(instance, invocationArguments);
try {
if (adapterInformation.shouldAdaptReactiveToSuspended()) {
/*
* Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context.
* We're invoking a method without Continuation as we expect the method to return any sort of reactive type,
* therefore we need to strip the Continuation parameter.
*/
Object[] invocationArguments = new Object[args.length - 1];
System.arraycopy(args, 0, invocationArguments, 0, invocationArguments.length);
return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, invocationArguments);
}
return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, args);
} catch (RuntimeException e) {
throw e;
} catch (Throwable e) {
throw new RuntimeException(e);
}
return baseClassMethod.invoke(instance, args);
});
this.adapterInformation = adapterInformation;
}
@Override
protected boolean shouldAdaptReactiveToSuspended() {
return adapterInformation.shouldAdaptReactiveToSuspended();
}
/**
* Value object capturing whether a suspended Kotlin method (Coroutine method) can be bridged with a native or
* reactive fragment method.

47
src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java

@ -15,18 +15,6 @@ @@ -15,18 +15,6 @@
*/
package org.springframework.data.repository.core.support;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;
import kotlin.coroutines.Continuation;
import kotlin.coroutines.CoroutineContext;
import kotlinx.coroutines.flow.Flow;
import kotlinx.coroutines.flow.FlowKt;
import kotlinx.coroutines.reactor.ReactorContext;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Iterator;
@ -38,6 +26,8 @@ import java.util.concurrent.TimeUnit; @@ -38,6 +26,8 @@ import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Stream;
import kotlin.coroutines.Continuation;
import kotlinx.coroutines.reactive.ReactiveFlowKt;
import org.assertj.core.api.Assertions;
import org.assertj.core.data.Percentage;
import org.jetbrains.annotations.NotNull;
@ -49,6 +39,10 @@ import org.mockito.internal.stubbing.answers.AnswersWithDelay; @@ -49,6 +39,10 @@ import org.mockito.internal.stubbing.answers.AnswersWithDelay;
import org.mockito.internal.stubbing.answers.Returns;
import org.mockito.junit.jupiter.MockitoExtension;
import org.reactivestreams.Subscription;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.core.support.CoroutineRepositoryMetadataUnitTests.MyCoroutineRepository;
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocation;
@ -59,6 +53,12 @@ import org.springframework.lang.Nullable; @@ -59,6 +53,12 @@ import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* @author Christoph Strobl
* @author Johannes Englmeier
@ -244,29 +244,12 @@ class RepositoryMethodInvokerUnitTests { @@ -244,29 +244,12 @@ class RepositoryMethodInvokerUnitTests {
@Test // DATACMNS-1764
void capturesKotlinSuspendFunctionsCorrectly() throws Exception {
var result = Flux.just(new TestDummy());
var result = ReactiveFlowKt.asFlow(Flux.just(new TestDummy()));
when(query.execute(any())).thenReturn(result);
Flow<TestDummy> flow = new RepositoryMethodInvokerStub(MyCoroutineRepository.class, multicaster,
Flux<TestDummy> flux = new RepositoryMethodInvokerStub(MyCoroutineRepository.class, multicaster,
"suspendedQueryMethod", query::execute).invoke(mock(Continuation.class));
assertThat(multicaster).isEmpty();
FlowKt.toCollection(flow, new ArrayList<>(), new Continuation<ArrayList<? extends Object>>() {
ReactorContext ctx = new ReactorContext(reactor.util.context.Context.empty());
@NotNull
@Override
public CoroutineContext getContext() {
return ctx;
}
@Override
public void resumeWith(@NotNull Object o) {
}
});
flux.subscribe();
assertThat(multicaster.first().getResult().getState()).isEqualTo(State.SUCCESS);
assertThat(multicaster.first().getResult().getError()).isNull();

6
src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt

@ -25,6 +25,7 @@ import org.springframework.data.repository.core.RepositoryMetadata @@ -25,6 +25,7 @@ import org.springframework.data.repository.core.RepositoryMetadata
import org.springframework.data.repository.core.support.DummyReactiveRepositoryFactory
import org.springframework.data.repository.core.support.RepositoryComposition
import org.springframework.data.repository.core.support.RepositoryFragment
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener
import org.springframework.data.repository.reactive.ReactiveCrudRepository
import org.springframework.data.repository.sample.User
@ -42,7 +43,12 @@ class CoroutineCrudRepositoryCustomImplementationUnitTests { @@ -42,7 +43,12 @@ class CoroutineCrudRepositoryCustomImplementationUnitTests {
@BeforeEach
fun before() {
factory = CustomDummyReactiveRepositoryFactory(backingRepository)
factory.addInvocationListener(RepositoryMethodInvocationListener {
repositoryMethodInvocation ->
println(repositoryMethodInvocation)
})
coRepository = factory.getRepository(MyCoRepository::class.java)
}
@Test // DATACMNS-1508

12
src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt

@ -19,7 +19,6 @@ import io.mockk.every @@ -19,7 +19,6 @@ import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.toList
@ -28,6 +27,7 @@ import org.assertj.core.api.Assertions.assertThat @@ -28,6 +27,7 @@ import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.reactivestreams.Publisher
import org.springframework.data.repository.core.support.DummyReactiveRepositoryFactory
@ -199,7 +199,7 @@ class CoroutineCrudRepositoryUnitTests { @@ -199,7 +199,7 @@ class CoroutineCrudRepositoryUnitTests {
val sample = User()
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Mono.just(sample))
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(sample))
val result = runBlocking {
coRepository.findOne("foo")
@ -215,7 +215,7 @@ class CoroutineCrudRepositoryUnitTests { @@ -215,7 +215,7 @@ class CoroutineCrudRepositoryUnitTests {
val sample = User()
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Single.just(sample))
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(sample))
val result = runBlocking {
coRepository.findOne("foo")
@ -263,7 +263,7 @@ class CoroutineCrudRepositoryUnitTests { @@ -263,7 +263,7 @@ class CoroutineCrudRepositoryUnitTests {
val sample = User()
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Flux.just(sample), Flux.empty<User>())
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Flux.just(sample), Flux.empty<User>())
val result = runBlocking {
coRepository.findSuspendedMultiple("foo").toList()
@ -283,7 +283,7 @@ class CoroutineCrudRepositoryUnitTests { @@ -283,7 +283,7 @@ class CoroutineCrudRepositoryUnitTests {
val sample = User()
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Flux.just(sample), Flux.empty<User>())
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(listOf(sample)), Mono.empty<User>())
val result = runBlocking {
coRepository.findSuspendedAsList("foo")
@ -295,7 +295,7 @@ class CoroutineCrudRepositoryUnitTests { @@ -295,7 +295,7 @@ class CoroutineCrudRepositoryUnitTests {
coRepository.findSuspendedAsList("foo")
}
assertThat(emptyResult).isEmpty()
assertThat(emptyResult).isNull()
}
interface MyCoRepository : CoroutineCrudRepository<User, String> {

Loading…
Cancel
Save