diff --git a/src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java b/src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java index 4988ea9ec..e551d1767 100644 --- a/src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java +++ b/src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java @@ -17,6 +17,8 @@ package org.springframework.data.repository.core.support; import java.lang.reflect.Method; import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -24,11 +26,11 @@ import org.springframework.core.CollectionFactory; import org.springframework.core.MethodParameter; import org.springframework.core.convert.ConversionService; import org.springframework.core.convert.TypeDescriptor; -import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.convert.support.GenericConversionService; import org.springframework.data.repository.util.NullableWrapper; import org.springframework.data.repository.util.QueryExecutionConverters; import org.springframework.data.repository.util.ReactiveWrapperConverters; +import org.springframework.data.util.Streamable; import org.springframework.lang.Nullable; /** @@ -44,15 +46,15 @@ class QueryExecutionResultHandler { private final GenericConversionService conversionService; + private final Object mutex = new Object(); + + // concurrent access guarded by mutex. + private Map descriptorCache = Collections.emptyMap(); + /** * Creates a new {@link QueryExecutionResultHandler}. */ - public QueryExecutionResultHandler() { - - GenericConversionService conversionService = new DefaultConversionService(); - QueryExecutionConverters.registerConvertersIn(conversionService); - conversionService.removeConvertible(Object.class, Object.class); - + public QueryExecutionResultHandler(GenericConversionService conversionService) { this.conversionService = conversionService; } @@ -70,10 +72,37 @@ class QueryExecutionResultHandler { return result; } - MethodParameter parameter = new MethodParameter(method, -1); + ReturnTypeDescriptor descriptor = getOrCreateReturnTypeDescriptor(method); + + return postProcessInvocationResult(result, 0, descriptor); + } + + private ReturnTypeDescriptor getOrCreateReturnTypeDescriptor(Method method) { + + Map descriptorCache = this.descriptorCache; + ReturnTypeDescriptor descriptor = descriptorCache.get(method); + + if (descriptor == null) { - return postProcessInvocationResult(result, 0, parameter); + descriptor = ReturnTypeDescriptor.of(method); + Map updatedDescriptorCache; + + if (descriptorCache.isEmpty()) { + updatedDescriptorCache = Collections.singletonMap(method, descriptor); + } else { + updatedDescriptorCache = new HashMap<>(descriptorCache.size() + 1, 1); + updatedDescriptorCache.putAll(descriptorCache); + updatedDescriptorCache.put(method, descriptor); + + } + + synchronized (mutex) { + this.descriptorCache = updatedDescriptorCache; + } + } + + return descriptor; } /** @@ -81,13 +110,13 @@ class QueryExecutionResultHandler { * * @param result can be {@literal null}. * @param nestingLevel - * @param parameter must not be {@literal null}. + * @param descriptor must not be {@literal null}. * @return */ @Nullable - Object postProcessInvocationResult(@Nullable Object result, int nestingLevel, MethodParameter parameter) { + Object postProcessInvocationResult(@Nullable Object result, int nestingLevel, ReturnTypeDescriptor descriptor) { - TypeDescriptor returnTypeDescriptor = TypeDescriptor.nested(parameter, nestingLevel); + TypeDescriptor returnTypeDescriptor = descriptor.getReturnTypeDescriptor(nestingLevel); if (returnTypeDescriptor == null) { return result; @@ -100,7 +129,7 @@ class QueryExecutionResultHandler { if (QueryExecutionConverters.supports(expectedReturnType)) { // For a wrapper type, try nested resolution first - result = postProcessInvocationResult(result, nestingLevel + 1, parameter); + result = postProcessInvocationResult(result, nestingLevel + 1, descriptor); if (conversionRequired(WRAPPER_TYPE, returnTypeDescriptor)) { return conversionService.convert(new NullableWrapper(result), returnTypeDescriptor); @@ -122,7 +151,18 @@ class QueryExecutionResultHandler { return ReactiveWrapperConverters.toWrapper(result, expectedReturnType); } - return conversionService.canConvert(TypeDescriptor.forObject(result), returnTypeDescriptor) + if (result instanceof Collection) { + + TypeDescriptor elementDescriptor = descriptor.getReturnTypeDescriptor(nestingLevel + 1); + boolean requiresConversion = requiresConversion((Collection) result, expectedReturnType, elementDescriptor); + + if (!requiresConversion) { + return result; + } + } + + TypeDescriptor resultDescriptor = TypeDescriptor.forObject(result); + return conversionService.canConvert(resultDescriptor, returnTypeDescriptor) ? conversionService.convert(result, returnTypeDescriptor) : result; } @@ -130,6 +170,29 @@ class QueryExecutionResultHandler { return Map.class.equals(expectedReturnType) // ? CollectionFactory.createMap(expectedReturnType, 0) // : null; + + } + private boolean requiresConversion(Collection collection, Class expectedReturnType, + @Nullable TypeDescriptor elementDescriptor) { + + if (Streamable.class.isAssignableFrom(expectedReturnType) || !expectedReturnType.isInstance(collection)) { + return true; + } + + if (elementDescriptor == null || !Iterable.class.isAssignableFrom(expectedReturnType)) { + return false; + } + + Class type = elementDescriptor.getType(); + + for (Object o : collection) { + + if (!type.isInstance(o)) { + return true; + } + } + + return false; } /** @@ -178,4 +241,54 @@ class QueryExecutionResultHandler { || source == null // || Collection.class.isInstance(source); } + + /** + * Value object capturing {@link MethodParameter} and {@link TypeDescriptor}s for top and nested levels. + */ + static class ReturnTypeDescriptor { + + private final MethodParameter methodParameter; + private final TypeDescriptor typeDescriptor; + private final @Nullable TypeDescriptor nestedTypeDescriptor; + + private ReturnTypeDescriptor(Method method) { + this.methodParameter = new MethodParameter(method, -1); + this.typeDescriptor = TypeDescriptor.nested(this.methodParameter, 0); + this.nestedTypeDescriptor = TypeDescriptor.nested(this.methodParameter, 1); + } + + /** + * Create a {@link ReturnTypeDescriptor} from a {@link Method}. + * + * @param method + * @return + */ + public static ReturnTypeDescriptor of(Method method) { + return new ReturnTypeDescriptor(method); + } + + /** + * Return the {@link TypeDescriptor} for a nested type declared within the method parameter described by + * {@code nestingLevel} . + * + * @param nestingLevel the nesting level. {@code 0} is the first level, {@code 1} the next inner one. + * @return the {@link TypeDescriptor} or {@literal null} if it could not be obtained. + * @see TypeDescriptor#nested(MethodParameter, int) + */ + @Nullable + public TypeDescriptor getReturnTypeDescriptor(int nestingLevel) { + + // optimizing for nesting level 0 and 1 (Optional, List) + // nesting level 2 (Optional>) uses the slow path. + + switch (nestingLevel) { + case 0: + return typeDescriptor; + case 1: + return nestedTypeDescriptor; + default: + return TypeDescriptor.nested(this.methodParameter, nestingLevel); + } + } + } } diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactorySupport.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactorySupport.java index 54f84a2af..877a4f2d3 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactorySupport.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactorySupport.java @@ -44,6 +44,8 @@ import org.springframework.beans.factory.BeanClassLoaderAware; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.core.ResolvableType; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.convert.support.GenericConversionService; import org.springframework.data.projection.DefaultMethodInvokingMethodInterceptor; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; @@ -114,6 +116,13 @@ public abstract class RepositoryFactorySupport implements BeanClassLoaderAware, return o; }; + final static GenericConversionService CONVERSION_SERVICE = new DefaultConversionService(); + + static { + QueryExecutionConverters.registerConvertersIn(CONVERSION_SERVICE); + CONVERSION_SERVICE.removeConvertible(Object.class, Object.class); + } + private final Map repositoryInformationCache; private final List postProcessors; @@ -534,7 +543,7 @@ public abstract class RepositoryFactorySupport implements BeanClassLoaderAware, public QueryExecutorMethodInterceptor(RepositoryInformation repositoryInformation, ProjectionFactory projectionFactory) { - this.resultHandler = new QueryExecutionResultHandler(); + this.resultHandler = new QueryExecutionResultHandler(CONVERSION_SERVICE); Optional lookupStrategy = getQueryLookupStrategy(queryLookupStrategyKey, RepositoryFactorySupport.this.evaluationContextProvider); diff --git a/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java b/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java index 793fd05f9..6a2d4065e 100644 --- a/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java +++ b/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java @@ -178,7 +178,7 @@ public abstract class QueryExecutionConverters { return SUPPORTS_CACHE.computeIfAbsent(type, key -> { for (WrapperType candidate : WRAPPER_TYPES) { - if (candidate.getType().isAssignableFrom(type)) { + if (candidate.getType().isAssignableFrom(key)) { return true; } } diff --git a/src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java b/src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java index df7ebed45..1e64a7093 100755 --- a/src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java +++ b/src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java @@ -54,7 +54,7 @@ import org.springframework.data.util.Streamable; */ public class QueryExecutionResultHandlerUnitTests { - QueryExecutionResultHandler handler = new QueryExecutionResultHandler(); + QueryExecutionResultHandler handler = new QueryExecutionResultHandler(RepositoryFactorySupport.CONVERSION_SERVICE); @Test // DATACMNS-610 public void convertsListsToSet() throws Exception {