Browse Source

Consider declaring class when evaluating method return type for query method post-processing.

We now consider the declaring class to properly resolve type variable references for the result post-processing of a query method result.

Previously, we attempted to resolve the return type without considering the actual repository class resolving always Object instead of the type parameter.

Closes #3125
3.2.x
Mark Paluch 1 year ago
parent
commit
5c59d87f41
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 53
      src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java
  2. 11
      src/main/java/org/springframework/data/repository/core/support/QueryExecutorMethodInterceptor.java
  3. 27
      src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java

53
src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java

@ -15,7 +15,6 @@
*/ */
package org.springframework.data.repository.core.support; package org.springframework.data.repository.core.support;
import java.lang.reflect.Method;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -23,13 +22,16 @@ import java.util.Map;
import java.util.Optional; import java.util.Optional;
import org.springframework.core.CollectionFactory; import org.springframework.core.CollectionFactory;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.core.convert.ConversionService; import org.springframework.core.convert.ConversionService;
import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.TypeDescriptor;
import org.springframework.core.convert.support.GenericConversionService; import org.springframework.core.convert.support.GenericConversionService;
import org.springframework.data.repository.util.ClassUtils;
import org.springframework.data.repository.util.QueryExecutionConverters; import org.springframework.data.repository.util.QueryExecutionConverters;
import org.springframework.data.repository.util.ReactiveWrapperConverters; import org.springframework.data.repository.util.ReactiveWrapperConverters;
import org.springframework.data.util.NullableWrapper; import org.springframework.data.util.NullableWrapper;
import org.springframework.data.util.ReactiveWrappers;
import org.springframework.data.util.Streamable; import org.springframework.data.util.Streamable;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
@ -44,12 +46,14 @@ class QueryExecutionResultHandler {
private static final TypeDescriptor WRAPPER_TYPE = TypeDescriptor.valueOf(NullableWrapper.class); private static final TypeDescriptor WRAPPER_TYPE = TypeDescriptor.valueOf(NullableWrapper.class);
private static final Class<?> FLOW_TYPE = loadIfPresent("kotlinx.coroutines.flow.Flow");
private final GenericConversionService conversionService; private final GenericConversionService conversionService;
private final Object mutex = new Object(); private final Object mutex = new Object();
// concurrent access guarded by mutex. // concurrent access guarded by mutex.
private Map<Method, ReturnTypeDescriptor> descriptorCache = Collections.emptyMap(); private Map<MethodParameter, ReturnTypeDescriptor> descriptorCache = Collections.emptyMap();
/** /**
* Creates a new {@link QueryExecutionResultHandler}. * Creates a new {@link QueryExecutionResultHandler}.
@ -58,6 +62,17 @@ class QueryExecutionResultHandler {
this.conversionService = conversionService; this.conversionService = conversionService;
} }
@Nullable
@SuppressWarnings("unchecked")
public static <T> Class<T> loadIfPresent(String type) {
try {
return (Class<T>) org.springframework.util.ClassUtils.forName(type, ClassUtils.class.getClassLoader());
} catch (ClassNotFoundException | LinkageError e) {
return null;
}
}
/** /**
* Post-processes the given result of a query invocation to match the return type of the given method. * Post-processes the given result of a query invocation to match the return type of the given method.
* *
@ -66,9 +81,9 @@ class QueryExecutionResultHandler {
* @return * @return
*/ */
@Nullable @Nullable
Object postProcessInvocationResult(@Nullable Object result, Method method) { Object postProcessInvocationResult(@Nullable Object result, MethodParameter method) {
if (!processingRequired(result, method.getReturnType())) { if (!processingRequired(result, method)) {
return result; return result;
} }
@ -77,16 +92,16 @@ class QueryExecutionResultHandler {
return postProcessInvocationResult(result, 0, descriptor); return postProcessInvocationResult(result, 0, descriptor);
} }
private ReturnTypeDescriptor getOrCreateReturnTypeDescriptor(Method method) { private ReturnTypeDescriptor getOrCreateReturnTypeDescriptor(MethodParameter method) {
Map<Method, ReturnTypeDescriptor> descriptorCache = this.descriptorCache; Map<MethodParameter, ReturnTypeDescriptor> descriptorCache = this.descriptorCache;
ReturnTypeDescriptor descriptor = descriptorCache.get(method); ReturnTypeDescriptor descriptor = descriptorCache.get(method);
if (descriptor == null) { if (descriptor == null) {
descriptor = ReturnTypeDescriptor.of(method); descriptor = ReturnTypeDescriptor.of(method);
Map<Method, ReturnTypeDescriptor> updatedDescriptorCache; Map<MethodParameter, ReturnTypeDescriptor> updatedDescriptorCache;
if (descriptorCache.isEmpty()) { if (descriptorCache.isEmpty()) {
updatedDescriptorCache = Collections.singletonMap(method, descriptor); updatedDescriptorCache = Collections.singletonMap(method, descriptor);
@ -94,7 +109,6 @@ class QueryExecutionResultHandler {
updatedDescriptorCache = new HashMap<>(descriptorCache.size() + 1, 1); updatedDescriptorCache = new HashMap<>(descriptorCache.size() + 1, 1);
updatedDescriptorCache.putAll(descriptorCache); updatedDescriptorCache.putAll(descriptorCache);
updatedDescriptorCache.put(method, descriptor); updatedDescriptorCache.put(method, descriptor);
} }
synchronized (mutex) { synchronized (mutex) {
@ -234,10 +248,21 @@ class QueryExecutionResultHandler {
* Returns whether we have to process the given source object in the first place. * Returns whether we have to process the given source object in the first place.
* *
* @param source can be {@literal null}. * @param source can be {@literal null}.
* @param targetType must not be {@literal null}. * @param methodParameter must not be {@literal null}.
* @return * @return
*/ */
private static boolean processingRequired(@Nullable Object source, Class<?> targetType) { private static boolean processingRequired(@Nullable Object source, MethodParameter methodParameter) {
Class<?> targetType = methodParameter.getParameterType();
if (source != null && ReactiveWrappers.KOTLIN_COROUTINES_PRESENT
&& KotlinDetector.isSuspendingFunction(methodParameter.getMethod())) {
// Spring's AOP invoker handles Publisher to Flow conversion, so we have to exempt these from post-processing.
if (FLOW_TYPE != null && FLOW_TYPE.isAssignableFrom(targetType)) {
return false;
}
}
return !targetType.isInstance(source) // return !targetType.isInstance(source) //
|| source == null // || source == null //
@ -253,19 +278,19 @@ class QueryExecutionResultHandler {
private final TypeDescriptor typeDescriptor; private final TypeDescriptor typeDescriptor;
private final @Nullable TypeDescriptor nestedTypeDescriptor; private final @Nullable TypeDescriptor nestedTypeDescriptor;
private ReturnTypeDescriptor(Method method) { private ReturnTypeDescriptor(MethodParameter methodParameter) {
this.methodParameter = new MethodParameter(method, -1); this.methodParameter = methodParameter;
this.typeDescriptor = TypeDescriptor.nested(this.methodParameter, 0); this.typeDescriptor = TypeDescriptor.nested(this.methodParameter, 0);
this.nestedTypeDescriptor = TypeDescriptor.nested(this.methodParameter, 1); this.nestedTypeDescriptor = TypeDescriptor.nested(this.methodParameter, 1);
} }
/** /**
* Create a {@link ReturnTypeDescriptor} from a {@link Method}. * Create a {@link ReturnTypeDescriptor} from a {@link MethodParameter}.
* *
* @param method * @param method
* @return * @return
*/ */
public static ReturnTypeDescriptor of(Method method) { public static ReturnTypeDescriptor of(MethodParameter method) {
return new ReturnTypeDescriptor(method); return new ReturnTypeDescriptor(method);
} }

11
src/main/java/org/springframework/data/repository/core/support/QueryExecutorMethodInterceptor.java

@ -21,9 +21,12 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation; import org.aopalliance.intercept.MethodInvocation;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.repository.core.NamedQueries; import org.springframework.data.repository.core.NamedQueries;
@ -55,6 +58,7 @@ class QueryExecutorMethodInterceptor implements MethodInterceptor {
private final RepositoryInformation repositoryInformation; private final RepositoryInformation repositoryInformation;
private final Map<Method, RepositoryQuery> queries; private final Map<Method, RepositoryQuery> queries;
private final Map<Method, RepositoryMethodInvoker> invocationMetadataCache = new ConcurrentReferenceHashMap<>(); private final Map<Method, RepositoryMethodInvoker> invocationMetadataCache = new ConcurrentReferenceHashMap<>();
private final Map<Method, MethodParameter> returnTypeMap = new ConcurrentHashMap<>();
private final QueryExecutionResultHandler resultHandler; private final QueryExecutionResultHandler resultHandler;
private final NamedQueries namedQueries; private final NamedQueries namedQueries;
private final List<QueryCreationListener<?>> queryPostProcessors; private final List<QueryCreationListener<?>> queryPostProcessors;
@ -135,16 +139,17 @@ class QueryExecutorMethodInterceptor implements MethodInterceptor {
public Object invoke(@SuppressWarnings("null") MethodInvocation invocation) throws Throwable { public Object invoke(@SuppressWarnings("null") MethodInvocation invocation) throws Throwable {
Method method = invocation.getMethod(); Method method = invocation.getMethod();
MethodParameter returnType = returnTypeMap.computeIfAbsent(method, it -> new MethodParameter(it, -1));
QueryExecutionConverters.ExecutionAdapter executionAdapter = QueryExecutionConverters // QueryExecutionConverters.ExecutionAdapter executionAdapter = QueryExecutionConverters //
.getExecutionAdapter(method.getReturnType()); .getExecutionAdapter(returnType.getParameterType());
if (executionAdapter == null) { if (executionAdapter == null) {
return resultHandler.postProcessInvocationResult(doInvoke(invocation), method); return resultHandler.postProcessInvocationResult(doInvoke(invocation), returnType);
} }
return executionAdapter // return executionAdapter //
.apply(() -> resultHandler.postProcessInvocationResult(doInvoke(invocation), method)); .apply(() -> resultHandler.postProcessInvocationResult(doInvoke(invocation), returnType));
} }
@Nullable @Nullable

27
src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java

@ -26,7 +26,6 @@ import io.vavr.control.Try;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.lang.reflect.Method;
import java.math.BigDecimal; import java.math.BigDecimal;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -40,6 +39,8 @@ import java.util.stream.Collectors;
import org.assertj.core.api.SoftAssertions; import org.assertj.core.api.SoftAssertions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.springframework.core.MethodParameter;
import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.repository.Repository; import org.springframework.data.repository.Repository;
import org.springframework.data.util.Streamable; import org.springframework.data.util.Streamable;
@ -404,6 +405,17 @@ class QueryExecutionResultHandlerUnitTests {
}); });
} }
@Test // GH-3125
void considersTypeBoundsFromBaseInterface() throws NoSuchMethodException {
var method = CustomizedRepository.class.getMethod("findById", Object.class);
var result = handler.postProcessInvocationResult(Optional.of(new Entity()),
new MethodParameter(method, -1).withContainingClass(CustomizedRepository.class));
assertThat(result).isInstanceOf(Entity.class);
}
@Test // DATACMNS-1552 @Test // DATACMNS-1552
void keepsVavrOptionType() throws Exception { void keepsVavrOptionType() throws Exception {
@ -412,8 +424,17 @@ class QueryExecutionResultHandlerUnitTests {
assertThat(handler.postProcessInvocationResult(source, getMethod("option"))).isSameAs(source); assertThat(handler.postProcessInvocationResult(source, getMethod("option"))).isSameAs(source);
} }
private static Method getMethod(String methodName) throws Exception { private static MethodParameter getMethod(String methodName) throws Exception {
return Sample.class.getMethod(methodName); return new MethodParameter(Sample.class.getMethod(methodName), -1);
}
interface BaseRepository<T, ID> extends Repository<T, ID> {
T findById(ID id);
}
interface CustomizedRepository extends BaseRepository<Entity, Long> {
} }
static interface Sample extends Repository<Entity, Long> { static interface Sample extends Repository<Entity, Long> {

Loading…
Cancel
Save