Browse Source

Fix query execution mode detection for aggregate types that implement Streamable.

We now short-circuit the QueryMethod.isCollectionQuery() algorithm in case we find the concrete domain type or any subclass of it.

Fixes #2869.
pull/2874/head
Oliver Drotbohm 3 years ago
parent
commit
ca9f9bfdc8
No known key found for this signature in database
GPG Key ID: C25FBFA0DA493A1D
  1. 5
      src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java
  2. 13
      src/main/java/org/springframework/data/repository/query/QueryMethod.java
  3. 25
      src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java
  4. 40
      src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java
  5. 44
      src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java

5
src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java

@ -96,12 +96,13 @@ public abstract class AbstractRepositoryMetadata implements RepositoryMetadata { @@ -96,12 +96,13 @@ public abstract class AbstractRepositoryMetadata implements RepositoryMetadata {
return returnType;
}
@Override
public Class<?> getReturnedDomainClass(Method method) {
TypeInformation<?> returnType = getReturnType(method);
returnType = ReactiveWrapperConverters.unwrapWrapperTypes(returnType);
return QueryExecutionConverters.unwrapWrapperTypes(ReactiveWrapperConverters.unwrapWrapperTypes(returnType))
.getType();
return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation()).getType();
}
public Class<?> getRepositoryInterface() {

13
src/main/java/org/springframework/data/repository/query/QueryMethod.java

@ -24,16 +24,17 @@ import java.util.stream.Stream; @@ -24,16 +24,17 @@ import java.util.stream.Stream;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Window;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Window;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.repository.core.EntityMetadata;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.util.QueryExecutionConverters;
import org.springframework.data.repository.util.ReactiveWrapperConverters;
import org.springframework.data.util.Lazy;
import org.springframework.data.util.NullableWrapperConverters;
import org.springframework.data.util.ReactiveWrappers;
import org.springframework.data.util.TypeInformation;
import org.springframework.util.Assert;
@ -296,7 +297,15 @@ public class QueryMethod { @@ -296,7 +297,15 @@ public class QueryMethod {
return false;
}
Class<?> returnType = metadata.getReturnType(method).getType();
TypeInformation<?> returnTypeInformation = metadata.getReturnType(method);
// Check against simple wrapper types first
if (metadata.getDomainTypeInformation()
.isAssignableFrom(NullableWrapperConverters.unwrapActualType(returnTypeInformation))) {
return false;
}
Class<?> returnType = returnTypeInformation.getType();
if (QueryExecutionConverters.supports(returnType) && !QueryExecutionConverters.isSingleValue(returnType)) {
return true;

25
src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java

@ -36,8 +36,8 @@ import org.springframework.core.convert.converter.GenericConverter; @@ -36,8 +36,8 @@ import org.springframework.core.convert.converter.GenericConverter;
import org.springframework.core.convert.support.ConfigurableConversionService;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Window;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Window;
import org.springframework.data.geo.GeoResults;
import org.springframework.data.util.CustomCollections;
import org.springframework.data.util.NullableWrapper;
@ -85,6 +85,7 @@ public abstract class QueryExecutionConverters { @@ -85,6 +85,7 @@ public abstract class QueryExecutionConverters {
private static final Set<Class<?>> ALLOWED_PAGEABLE_TYPES = new HashSet<>();
private static final Map<Class<?>, ExecutionAdapter> EXECUTION_ADAPTER = new HashMap<>();
private static final Map<Class<?>, Boolean> supportsCache = new ConcurrentReferenceHashMap<>();
private static final TypeInformation<Void> VOID_INFORMATION = TypeInformation.of(Void.class);
static {
@ -235,15 +236,21 @@ public abstract class QueryExecutionConverters { @@ -235,15 +236,21 @@ public abstract class QueryExecutionConverters {
}
/**
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
* Recursively unwraps well known wrapper types from the given {@link TypeInformation} but aborts at the given
* reference type.
*
* @param type must not be {@literal null}.
* @param reference must not be {@literal null}.
* @return will never be {@literal null}.
*/
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type, TypeInformation<?> reference) {
Assert.notNull(type, "type must not be null");
if (reference.isAssignableFrom(type)) {
return type;
}
Class<?> rawType = type.getType();
boolean needToUnwrap = type.isCollectionLike() //
@ -253,7 +260,17 @@ public abstract class QueryExecutionConverters { @@ -253,7 +260,17 @@ public abstract class QueryExecutionConverters {
|| supports(rawType) //
|| Stream.class.isAssignableFrom(rawType);
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType()) : type;
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType(), reference) : type;
}
/**
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
*
* @param type must not be {@literal null}.
* @return will never be {@literal null}.
*/
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
return unwrapWrapperTypes(type, VOID_INFORMATION);
}
/**

40
src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java

@ -21,14 +21,19 @@ import java.io.Serializable; @@ -21,14 +21,19 @@ import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestFactory;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.querydsl.User;
import org.springframework.data.repository.PagingAndSortingRepository;
import org.springframework.data.repository.Repository;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.util.Streamable;
/**
* Unit tests for {@link AbstractRepositoryMetadata}.
@ -111,6 +116,25 @@ class AbstractRepositoryMetadataUnitTests { @@ -111,6 +116,25 @@ class AbstractRepositoryMetadataUnitTests {
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(Container.class);
}
@TestFactory // GH-2869
Stream<DynamicTest> detectsReturnTypesForStreamableAggregates() throws Exception {
var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
var methods = Stream.of(
Map.entry("findBy", StreamableAggregate.class),
Map.entry("findSubTypeBy", StreamableAggregateSubType.class),
Map.entry("findAllBy", StreamableAggregate.class),
Map.entry("findOptional", StreamableAggregate.class));
return DynamicTest.stream(methods, //
it -> it.getKey() + "'s returned domain class is " + it.getValue(), //
it -> {
var method = StreamableAggregateRepository.class.getMethod(it.getKey());
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(it.getValue());
});
}
interface UserRepository extends Repository<User, Long> {
User findSingle();
@ -155,4 +179,20 @@ class AbstractRepositoryMetadataUnitTests { @@ -155,4 +179,20 @@ class AbstractRepositoryMetadataUnitTests {
interface CompletePageableAndSortingRepository extends PagingAndSortingRepository<Container, Long> {}
// GH-2869
static abstract class StreamableAggregate implements Streamable<Object> {}
interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {
StreamableAggregate findBy();
StreamableAggregateSubType findSubTypeBy();
Streamable<StreamableAggregate> findAllBy();
Optional<StreamableAggregate> findOptional();
}
static abstract class StreamableAggregateSubType extends StreamableAggregate {}
}

44
src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java

@ -24,12 +24,16 @@ import reactor.core.publisher.Mono; @@ -24,12 +24,16 @@ import reactor.core.publisher.Mono;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.stream.Stream;
import org.eclipse.collections.api.list.ImmutableList;
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestFactory;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.ScrollPosition;
@ -41,6 +45,7 @@ import org.springframework.data.repository.Repository; @@ -41,6 +45,7 @@ import org.springframework.data.repository.Repository;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
import org.springframework.data.util.Streamable;
/**
* Unit tests for {@link QueryMethod}.
@ -302,6 +307,28 @@ class QueryMethodUnitTests { @@ -302,6 +307,28 @@ class QueryMethodUnitTests {
assertThat(queryMethod.isCollectionQuery()).isTrue();
}
@TestFactory // GH-2869
Stream<DynamicTest> doesNotConsiderQueryMethodReturningAggregateImplementingStreamableACollectionQuery()
throws Exception {
var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
var stream = Stream.of(
Map.entry("findBy", false),
Map.entry("findSubTypeBy", false),
Map.entry("findAllBy", true),
Map.entry("findOptionalBy", false));
return DynamicTest.stream(stream, //
it -> it.getKey() + " considered collection query -> " + it.getValue(), //
it -> {
var method = StreamableAggregateRepository.class.getMethod(it.getKey());
var queryMethod = new QueryMethod(method, metadata, factory);
assertThat(queryMethod.isCollectionQuery()).isEqualTo(it.getValue());
});
}
interface SampleRepository extends Repository<User, Serializable> {
String pagingMethodWithInvalidReturnType(Pageable pageable);
@ -379,4 +406,21 @@ class QueryMethodUnitTests { @@ -379,4 +406,21 @@ class QueryMethodUnitTests {
interface ContainerRepository extends Repository<Container, Long> {
Container someMethod();
}
// GH-2869
static abstract class StreamableAggregate implements Streamable<Object> {}
interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {
StreamableAggregate findBy();
StreamableAggregateSubType findSubTypeBy();
Optional<StreamableAggregate> findOptionalBy();
Streamable<StreamableAggregate> findAllBy();
}
static abstract class StreamableAggregateSubType extends StreamableAggregate {}
}

Loading…
Cancel
Save