Browse Source

Use projecting read callback to allow interface projections.

Along the lines fix entity operations proxy handling by reading the underlying map instead of inspecting the proxy interface.
Also make sure to map potential raw fields back to the according property.

See: #4308
Original Pull Request: #4317
pull/4334/head
Christoph Strobl 3 years ago
parent
commit
d8c04f0ec9
No known key found for this signature in database
GPG Key ID: 8CC1AB53391458C8
  1. 77
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java
  2. 5
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java
  3. 12
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java
  4. 5
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ScrollUtils.java
  5. 40
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/EntityOperationsUnitTests.java
  6. 239
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateScrollTests.java
  7. 95
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateScrollTests.java

77
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java

@ -30,6 +30,8 @@ import org.springframework.data.mapping.IdentifierAccessor;
import org.springframework.data.mapping.MappingException; import org.springframework.data.mapping.MappingException;
import org.springframework.data.mapping.PersistentEntity; import org.springframework.data.mapping.PersistentEntity;
import org.springframework.data.mapping.PersistentPropertyAccessor; import org.springframework.data.mapping.PersistentPropertyAccessor;
import org.springframework.data.mapping.PersistentPropertyPath;
import org.springframework.data.mapping.PropertyPath;
import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mapping.model.ConvertingPropertyAccessor; import org.springframework.data.mapping.model.ConvertingPropertyAccessor;
import org.springframework.data.mongodb.core.CollectionOptions.TimeSeriesOptions; import org.springframework.data.mongodb.core.CollectionOptions.TimeSeriesOptions;
@ -50,6 +52,7 @@ import org.springframework.data.mongodb.util.BsonUtils;
import org.springframework.data.projection.EntityProjection; import org.springframework.data.projection.EntityProjection;
import org.springframework.data.projection.EntityProjectionIntrospector; import org.springframework.data.projection.EntityProjectionIntrospector;
import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.projection.TargetAware;
import org.springframework.data.util.Optionals; import org.springframework.data.util.Optionals;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -117,12 +120,16 @@ class EntityOperations {
Assert.notNull(entity, "Bean must not be null"); Assert.notNull(entity, "Bean must not be null");
if (entity instanceof TargetAware targetAware) {
return new SimpleMappedEntity((Map<String, Object>) targetAware.getTarget(), this);
}
if (entity instanceof String) { if (entity instanceof String) {
return new UnmappedEntity(parse(entity.toString())); return new UnmappedEntity(parse(entity.toString()), this);
} }
if (entity instanceof Map) { if (entity instanceof Map) {
return new SimpleMappedEntity((Map<String, Object>) entity); return new SimpleMappedEntity((Map<String, Object>) entity, this);
} }
return MappedEntity.of(entity, context, this); return MappedEntity.of(entity, context, this);
@ -142,11 +149,11 @@ class EntityOperations {
Assert.notNull(conversionService, "ConversionService must not be null"); Assert.notNull(conversionService, "ConversionService must not be null");
if (entity instanceof String) { if (entity instanceof String) {
return new UnmappedEntity(parse(entity.toString())); return new UnmappedEntity(parse(entity.toString()), this);
} }
if (entity instanceof Map) { if (entity instanceof Map) {
return new SimpleMappedEntity((Map<String, Object>) entity); return new SimpleMappedEntity((Map<String, Object>) entity, this);
} }
return AdaptibleMappedEntity.of(entity, context, conversionService, this); return AdaptibleMappedEntity.of(entity, context, conversionService, this);
@ -287,7 +294,8 @@ class EntityOperations {
*/ */
public <M, D> EntityProjection<M, D> introspectProjection(Class<M> resultType, Class<D> entityType) { public <M, D> EntityProjection<M, D> introspectProjection(Class<M> resultType, Class<D> entityType) {
if (!queryMapper.getMappingContext().hasPersistentEntityFor(entityType)) { MongoPersistentEntity<?> persistentEntity = queryMapper.getMappingContext().getPersistentEntity(entityType);
if (persistentEntity == null && !resultType.isInterface() || ClassUtils.isAssignable(Document.class, resultType)) {
return (EntityProjection) EntityProjection.nonProjecting(resultType); return (EntityProjection) EntityProjection.nonProjecting(resultType);
} }
return introspector.introspect(resultType, entityType); return introspector.introspect(resultType, entityType);
@ -369,6 +377,7 @@ class EntityOperations {
* A representation of information about an entity. * A representation of information about an entity.
* *
* @author Oliver Gierke * @author Oliver Gierke
* @author Christoph Strobl
* @since 2.1 * @since 2.1
*/ */
interface Entity<T> { interface Entity<T> {
@ -471,10 +480,10 @@ class EntityOperations {
/** /**
* @param sortObject * @param sortObject
* @return * @return
* @since 3.1 * @since 4.1
* @throws IllegalStateException if a sort key yields {@literal null}. * @throws IllegalStateException if a sort key yields {@literal null}.
*/ */
Map<String, Object> extractKeys(Document sortObject); Map<String, Object> extractKeys(Document sortObject, Class<?> sourceType);
} }
@ -523,9 +532,11 @@ class EntityOperations {
private static class UnmappedEntity<T extends Map<String, Object>> implements AdaptibleEntity<T> { private static class UnmappedEntity<T extends Map<String, Object>> implements AdaptibleEntity<T> {
private final T map; private final T map;
private final EntityOperations entityOperations;
protected UnmappedEntity(T map) { protected UnmappedEntity(T map, EntityOperations entityOperations) {
this.map = map; this.map = map;
this.entityOperations = entityOperations;
} }
@Override @Override
@ -596,13 +607,19 @@ class EntityOperations {
} }
@Override @Override
public Map<String, Object> extractKeys(Document sortObject) { public Map<String, Object> extractKeys(Document sortObject, Class<?> sourceType) {
Map<String, Object> keyset = new LinkedHashMap<>(); Map<String, Object> keyset = new LinkedHashMap<>();
keyset.put(ID_FIELD, getId()); MongoPersistentEntity<?> sourceEntity = entityOperations.context.getPersistentEntity(sourceType);
if (sourceEntity != null && sourceEntity.hasIdProperty()) {
keyset.put(sourceEntity.getRequiredIdProperty().getName(), getId());
} else {
keyset.put(ID_FIELD, getId());
}
for (String key : sortObject.keySet()) { for (String key : sortObject.keySet()) {
Object value = BsonUtils.resolveValue(map, key);
Object value = resolveValue(key, sourceEntity);
if (value == null) { if (value == null) {
throw new IllegalStateException( throw new IllegalStateException(
@ -614,12 +631,24 @@ class EntityOperations {
return keyset; return keyset;
} }
@Nullable
private Object resolveValue(String key, @Nullable MongoPersistentEntity<?> sourceEntity) {
if (sourceEntity == null) {
return BsonUtils.resolveValue(map, key);
}
PropertyPath from = PropertyPath.from(key, sourceEntity.getTypeInformation());
PersistentPropertyPath<MongoPersistentProperty> persistentPropertyPath = entityOperations.context
.getPersistentPropertyPath(from);
return BsonUtils.resolveValue(map, persistentPropertyPath.toDotPath(p -> p.getFieldName()));
}
} }
private static class SimpleMappedEntity<T extends Map<String, Object>> extends UnmappedEntity<T> { private static class SimpleMappedEntity<T extends Map<String, Object>> extends UnmappedEntity<T> {
protected SimpleMappedEntity(T map) { protected SimpleMappedEntity(T map, EntityOperations entityOperations) {
super(map); super(map, entityOperations);
} }
@Override @Override
@ -758,10 +787,15 @@ class EntityOperations {
} }
@Override @Override
public Map<String, Object> extractKeys(Document sortObject) { public Map<String, Object> extractKeys(Document sortObject, Class<?> sourceType) {
Map<String, Object> keyset = new LinkedHashMap<>(); Map<String, Object> keyset = new LinkedHashMap<>();
keyset.put(entity.getRequiredIdProperty().getName(), getId()); MongoPersistentEntity<?> sourceEntity = entityOperations.context.getPersistentEntity(sourceType);
if (sourceEntity != null && sourceEntity.hasIdProperty()) {
keyset.put(sourceEntity.getRequiredIdProperty().getName(), getId());
} else {
keyset.put(entity.getRequiredIdProperty().getName(), getId());
}
for (String key : sortObject.keySet()) { for (String key : sortObject.keySet()) {
@ -933,6 +967,14 @@ class EntityOperations {
* @since 3.3 * @since 3.3
*/ */
TimeSeriesOptions mapTimeSeriesOptions(TimeSeriesOptions options); TimeSeriesOptions mapTimeSeriesOptions(TimeSeriesOptions options);
/**
* @return the name of the id field.
* @since 4.1
*/
default String getIdKeyName() {
return ID_FIELD;
}
} }
/** /**
@ -1055,6 +1097,11 @@ class EntityOperations {
MongoPersistentProperty persistentProperty = entity.getPersistentProperty(name); MongoPersistentProperty persistentProperty = entity.getPersistentProperty(name);
return persistentProperty != null ? persistentProperty.getFieldName() : name; return persistentProperty != null ? persistentProperty.getFieldName() : name;
} }
@Override
public String getIdKeyName() {
return entity.getIdProperty().getName();
}
} }
} }

5
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

@ -870,7 +870,8 @@ public class MongoTemplate
Assert.notNull(sourceClass, "Entity type must not be null"); Assert.notNull(sourceClass, "Entity type must not be null");
Assert.notNull(targetClass, "Target type must not be null"); Assert.notNull(targetClass, "Target type must not be null");
ReadDocumentCallback<T> callback = new ReadDocumentCallback<>(mongoConverter, targetClass, collectionName); EntityProjection<T, ?> projection = operations.introspectProjection(targetClass, sourceClass);
ProjectingReadCallback<?,T> callback = new ProjectingReadCallback<>(mongoConverter, projection, collectionName);
int limit = query.isLimited() ? query.getLimit() + 1 : Integer.MAX_VALUE; int limit = query.isLimited() ? query.getLimit() + 1 : Integer.MAX_VALUE;
if (query.hasKeyset()) { if (query.hasKeyset()) {
@ -882,7 +883,7 @@ public class MongoTemplate
keysetPaginationQuery.fields(), sourceClass, keysetPaginationQuery.fields(), sourceClass,
new QueryCursorPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass), callback); new QueryCursorPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass), callback);
return ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), result, operations); return ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), result, sourceClass, operations);
} }
List<T> result = doFind(collectionName, createDelegate(query), query.getQueryObject(), query.getFieldsObject(), List<T> result = doFind(collectionName, createDelegate(query), query.getQueryObject(), query.getFieldsObject(),

12
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java

@ -849,6 +849,8 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati
Assert.notNull(sourceClass, "Entity type must not be null"); Assert.notNull(sourceClass, "Entity type must not be null");
Assert.notNull(targetClass, "Target type must not be null"); Assert.notNull(targetClass, "Target type must not be null");
EntityProjection<T, ?> projection = operations.introspectProjection(targetClass, sourceClass);
ProjectingReadCallback<?,T> callback = new ProjectingReadCallback<>(mongoConverter, projection, collectionName);
int limit = query.isLimited() ? query.getLimit() + 1 : Integer.MAX_VALUE; int limit = query.isLimited() ? query.getLimit() + 1 : Integer.MAX_VALUE;
if (query.hasKeyset()) { if (query.hasKeyset()) {
@ -857,15 +859,15 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati
operations.getIdPropertyName(sourceClass)); operations.getIdPropertyName(sourceClass));
Mono<List<T>> result = doFind(collectionName, ReactiveCollectionPreparerDelegate.of(query), Mono<List<T>> result = doFind(collectionName, ReactiveCollectionPreparerDelegate.of(query),
keysetPaginationQuery.query(), keysetPaginationQuery.fields(), targetClass, keysetPaginationQuery.query(), keysetPaginationQuery.fields(), sourceClass,
new QueryFindPublisherPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass)).collectList(); new QueryFindPublisherPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass), callback).collectList();
return result.map(it -> ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), it, operations)); return result.map(it -> ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), it, sourceClass, operations));
} }
Mono<List<T>> result = doFind(collectionName, ReactiveCollectionPreparerDelegate.of(query), query.getQueryObject(), Mono<List<T>> result = doFind(collectionName, ReactiveCollectionPreparerDelegate.of(query), query.getQueryObject(),
query.getFieldsObject(), targetClass, query.getFieldsObject(), sourceClass,
new QueryFindPublisherPreparer(query, query.getSortObject(), limit, query.getSkip(), sourceClass)) new QueryFindPublisherPreparer(query, query.getSortObject(), limit, query.getSkip(), sourceClass), callback)
.collectList(); .collectList();
return result.map( return result.map(

5
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ScrollUtils.java

@ -121,14 +121,15 @@ class ScrollUtils {
return sortOrder == 1 ? "$gt" : "$lt"; return sortOrder == 1 ? "$gt" : "$lt";
} }
static <T> Window<T> createWindow(Document sortObject, int limit, List<T> result, EntityOperations operations) { static <T> Window<T> createWindow(Document sortObject, int limit, List<T> result, Class<?> sourceType,
EntityOperations operations) {
IntFunction<KeysetScrollPosition> positionFunction = value -> { IntFunction<KeysetScrollPosition> positionFunction = value -> {
T last = result.get(value); T last = result.get(value);
Entity<T> entity = operations.forEntity(last); Entity<T> entity = operations.forEntity(last);
Map<String, Object> keys = entity.extractKeys(sortObject); Map<String, Object> keys = entity.extractKeys(sortObject, sourceType);
return KeysetScrollPosition.of(keys); return KeysetScrollPosition.of(keys);
}; };

40
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/EntityOperationsUnitTests.java

@ -33,11 +33,13 @@ import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
import org.springframework.data.mongodb.core.mapping.TimeSeries; import org.springframework.data.mongodb.core.mapping.TimeSeries;
import org.springframework.data.mongodb.test.util.MongoTestMappingContext; import org.springframework.data.mongodb.test.util.MongoTestMappingContext;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
/** /**
* Unit tests for {@link EntityOperations}. * Unit tests for {@link EntityOperations}.
* *
* @author Mark Paluch * @author Mark Paluch
* @author Christoph Strobl
*/ */
class EntityOperationsUnitTests { class EntityOperationsUnitTests {
@ -70,7 +72,8 @@ class EntityOperationsUnitTests {
WithNestedDocument object = new WithNestedDocument("foo"); WithNestedDocument object = new WithNestedDocument("foo");
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("id", 1)); Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("id", 1),
WithNestedDocument.class);
assertThat(keys).containsEntry("id", "foo"); assertThat(keys).containsEntry("id", "foo");
} }
@ -80,7 +83,7 @@ class EntityOperationsUnitTests {
Document object = new Document("id", "foo"); Document object = new Document("id", "foo");
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("id", 1)); Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("id", 1), Document.class);
assertThat(keys).containsEntry("id", "foo"); assertThat(keys).containsEntry("id", "foo");
} }
@ -90,7 +93,8 @@ class EntityOperationsUnitTests {
WithNestedDocument object = new WithNestedDocument("foo", new WithNestedDocument("bar"), null); WithNestedDocument object = new WithNestedDocument("foo", new WithNestedDocument("bar"), null);
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("nested.id", 1)); Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("nested.id", 1),
WithNestedDocument.class);
assertThat(keys).containsEntry("nested.id", "bar"); assertThat(keys).containsEntry("nested.id", "bar");
} }
@ -101,7 +105,8 @@ class EntityOperationsUnitTests {
WithNestedDocument object = new WithNestedDocument("foo", new WithNestedDocument("bar"), WithNestedDocument object = new WithNestedDocument("foo", new WithNestedDocument("bar"),
new Document("john", "doe")); new Document("john", "doe"));
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("document.john", 1)); Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("document.john", 1),
WithNestedDocument.class);
assertThat(keys).containsEntry("document.john", "doe"); assertThat(keys).containsEntry("document.john", "doe");
} }
@ -111,11 +116,32 @@ class EntityOperationsUnitTests {
Document object = new Document("document", new Document("john", "doe")); Document object = new Document("document", new Document("john", "doe"));
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("document.john", 1)); Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("document.john", 1),
Document.class);
assertThat(keys).containsEntry("document.john", "doe"); assertThat(keys).containsEntry("document.john", "doe");
} }
@Test // GH-4308
void shouldExtractIdPropertyNameFromRawDocument() {
Document object = new Document("_id", "id-1").append("value", "val");
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("value", 1), DomainTypeWithIdProperty.class);
assertThat(keys).containsEntry("id", "id-1");
}
@Test // GH-4308
void shouldExtractValuesFromProxy() {
ProjectionInterface source = new SpelAwareProxyProjectionFactory().createProjection(ProjectionInterface.class, new Document("_id", "id-1").append("value", "val"));
Map<String, Object> keys = operations.forEntity(source).extractKeys(new Document("value", 1), DomainTypeWithIdProperty.class);
assertThat(keys).isEqualTo(new Document("id", "id-1").append("value", "val"));
}
<T> EntityOperations.AdaptibleEntity<T> initAdaptibleEntity(T source) { <T> EntityOperations.AdaptibleEntity<T> initAdaptibleEntity(T source) {
return operations.forEntity(source, conversionService); return operations.forEntity(source, conversionService);
} }
@ -150,4 +176,8 @@ class EntityOperationsUnitTests {
this.id = id; this.id = id;
} }
} }
interface ProjectionInterface {
String getValue();
}
} }

239
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateScrollTests.java

@ -18,10 +18,13 @@ package org.springframework.data.mongodb.core;
import static org.assertj.core.api.Assertions.*; import static org.assertj.core.api.Assertions.*;
import static org.springframework.data.mongodb.core.query.Criteria.*; import static org.springframework.data.mongodb.core.query.Criteria.*;
import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.lang.reflect.Proxy;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -44,10 +47,13 @@ import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Window; import org.springframework.data.domain.Window;
import org.springframework.data.mapping.context.PersistentEntities; import org.springframework.data.mapping.context.PersistentEntities;
import org.springframework.data.mongodb.core.MongoTemplateTests.PersonWithIdPropertyOfTypeUUIDListener; import org.springframework.data.mongodb.core.MongoTemplateTests.PersonWithIdPropertyOfTypeUUIDListener;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.test.util.Client; import org.springframework.data.mongodb.test.util.Client;
import org.springframework.data.mongodb.test.util.MongoClientExtension; import org.springframework.data.mongodb.test.util.MongoClientExtension;
import org.springframework.data.mongodb.test.util.MongoTestTemplate; import org.springframework.data.mongodb.test.util.MongoTestTemplate;
import org.springframework.lang.Nullable;
import org.springframework.util.ObjectUtils;
import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClient;
@ -90,10 +96,22 @@ class MongoTemplateScrollTests {
}); });
}); });
private static int compareProxies(PersonInterfaceProjection actual, PersonInterfaceProjection expected) {
if (actual.getAge() != expected.getAge()) {
return -1;
}
if (!ObjectUtils.nullSafeEquals(actual.getFirstName(), expected.getFirstName())) {
return -1;
}
return 0;
}
@BeforeEach @BeforeEach
void setUp() { void setUp() {
template.remove(Person.class).all(); template.remove(Person.class).all();
template.remove(WithNestedDocument.class).all(); template.remove(WithNestedDocument.class).all();
template.remove(WithRenamedField.class).all();
} }
@Test // GH-4308 @Test // GH-4308
@ -112,19 +130,19 @@ class MongoTemplateScrollTests {
.limit(2); .limit(2);
q.with(KeysetScrollPosition.initial()); q.with(KeysetScrollPosition.initial());
Window<WithNestedDocument> scroll = template.scroll(q, WithNestedDocument.class); Window<WithNestedDocument> window = template.scroll(q, WithNestedDocument.class);
assertThat(scroll.hasNext()).isTrue(); assertThat(window.hasNext()).isTrue();
assertThat(scroll.isLast()).isFalse(); assertThat(window.isLast()).isFalse();
assertThat(scroll).hasSize(2); assertThat(window).hasSize(2);
assertThat(scroll).containsOnly(john20, john40); assertThat(window).containsOnly(john20, john40);
scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)), WithNestedDocument.class); window = template.scroll(q.with(window.positionAt(window.size() - 1)), WithNestedDocument.class);
assertThat(scroll.hasNext()).isFalse(); assertThat(window.hasNext()).isFalse();
assertThat(scroll.isLast()).isTrue(); assertThat(window.isLast()).isTrue();
assertThat(scroll).hasSize(1); assertThat(window).hasSize(1);
assertThat(scroll).containsOnly(john41); assertThat(window).containsOnly(john41);
} }
@Test // GH-4308 @Test // GH-4308
@ -162,35 +180,35 @@ class MongoTemplateScrollTests {
.limit(2); .limit(2);
q.with(KeysetScrollPosition.initial()); q.with(KeysetScrollPosition.initial());
Window<WithNestedDocument> scroll = template.scroll(q, WithNestedDocument.class); Window<WithNestedDocument> window = template.scroll(q, WithNestedDocument.class);
assertThat(scroll.hasNext()).isTrue(); assertThat(window.hasNext()).isTrue();
assertThat(scroll.isLast()).isFalse(); assertThat(window.isLast()).isFalse();
assertThat(scroll).hasSize(2); assertThat(window).hasSize(2);
assertThat(scroll).containsOnly(john20, john40); assertThat(window).containsOnly(john20, john40);
scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)), WithNestedDocument.class); window = template.scroll(q.with(window.positionAt(window.size() - 1)), WithNestedDocument.class);
assertThat(scroll.hasNext()).isFalse(); assertThat(window.hasNext()).isFalse();
assertThat(scroll.isLast()).isTrue(); assertThat(window.isLast()).isTrue();
assertThat(scroll).hasSize(1); assertThat(window).hasSize(1);
assertThat(scroll).containsOnly(john41); assertThat(window).containsOnly(john41);
KeysetScrollPosition scrollPosition = (KeysetScrollPosition) scroll.positionAt(0); KeysetScrollPosition scrollPosition = (KeysetScrollPosition) window.positionAt(0);
KeysetScrollPosition reversePosition = KeysetScrollPosition.of(scrollPosition.getKeys(), Direction.Backward); KeysetScrollPosition reversePosition = KeysetScrollPosition.of(scrollPosition.getKeys(), Direction.Backward);
scroll = template.scroll(q.with(reversePosition), WithNestedDocument.class); window = template.scroll(q.with(reversePosition), WithNestedDocument.class);
assertThat(scroll.hasNext()).isTrue(); assertThat(window.hasNext()).isTrue();
assertThat(scroll.isLast()).isFalse(); assertThat(window.isLast()).isFalse();
assertThat(scroll).hasSize(2); assertThat(window).hasSize(2);
assertThat(scroll).containsOnly(john20, john40); assertThat(window).containsOnly(john20, john40);
} }
@ParameterizedTest // GH-4308 @ParameterizedTest // GH-4308
@MethodSource("positions") @MethodSource("positions")
public <T> void shouldApplyCursoringCorrectly(ScrollPosition scrollPosition, Class<T> resultType, public <T> void shouldApplyCursoringCorrectly(ScrollPosition scrollPosition, Class<T> resultType,
Function<Person, T> assertionConverter) { Function<Person, T> assertionConverter, @Nullable Comparator<T> comparator) {
Person john20 = new Person("John", 20); Person john20 = new Person("John", 20);
Person john40_1 = new Person("John", 40); Person john40_1 = new Person("John", 40);
@ -201,53 +219,182 @@ class MongoTemplateScrollTests {
template.insertAll(Arrays.asList(john20, john40_1, john40_2, jane_20, jane_40, jane_42)); template.insertAll(Arrays.asList(john20, john40_1, john40_2, jane_20, jane_40, jane_42));
Query q = new Query(where("firstName").regex("J.*")).with(Sort.by("firstName", "age")).limit(2); Query q = new Query(where("firstName").regex("J.*")).with(Sort.by("firstName", "age")).limit(2);
q.with(scrollPosition);
Window<T> scroll = template.scroll(q, resultType, "person"); Window<T> window = template.query(Person.class).inCollection("person").as(resultType).matching(q)
.scroll(scrollPosition);
assertThat(window.hasNext()).isTrue();
assertThat(window.isLast()).isFalse();
assertThat(window).hasSize(2);
assertWindow(window, comparator).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40));
window = template.query(Person.class).inCollection("person").as(resultType).matching(q.limit(3))
.scroll(window.positionAt(window.size() - 1));
assertThat(window.hasNext()).isTrue();
assertThat(window.isLast()).isFalse();
assertThat(window).hasSize(3);
assertWindow(window, comparator).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20));
assertWindow(window, comparator).containsAnyOf(assertionConverter.apply(john40_1),
assertionConverter.apply(john40_2));
window = template.query(Person.class).inCollection("person").as(resultType).matching(q.limit(1))
.scroll(window.positionAt(window.size() - 1));
assertThat(window.hasNext()).isFalse();
assertThat(window.isLast()).isTrue();
assertThat(window).hasSize(1);
assertWindow(window, comparator).containsAnyOf(assertionConverter.apply(john40_1),
assertionConverter.apply(john40_2));
}
@ParameterizedTest // GH-4308
@MethodSource("renamedFieldProjectTargets")
<T> void scrollThroughResultsWithRenamedField(Class<T> resultType, Function<WithRenamedField, T> assertionConverter) {
WithRenamedField one = new WithRenamedField("id-1", "v1", null);
WithRenamedField two = new WithRenamedField("id-2", "v2", null);
WithRenamedField three = new WithRenamedField("id-3", "v3", null);
assertThat(scroll.hasNext()).isTrue(); template.insertAll(Arrays.asList(one, two, three));
assertThat(scroll.isLast()).isFalse();
assertThat(scroll).hasSize(2);
assertThat(scroll).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40));
scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)).limit(3), resultType, "person"); Query q = new Query(where("value").regex("v.*")).with(Sort.by(Sort.Direction.DESC, "value")).limit(2);
q.with(KeysetScrollPosition.initial());
Window<T> window = template.query(WithRenamedField.class).as(resultType).matching(q)
.scroll(KeysetScrollPosition.initial());
assertThat(scroll.hasNext()).isTrue(); assertThat(window.hasNext()).isTrue();
assertThat(scroll.isLast()).isFalse(); assertThat(window.isLast()).isFalse();
assertThat(scroll).hasSize(3); assertThat(window).hasSize(2);
assertThat(scroll).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20)); assertThat(window).containsOnly(assertionConverter.apply(three), assertionConverter.apply(two));
assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2));
scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)).limit(1), resultType, "person"); window = template.query(WithRenamedField.class).as(resultType).matching(q)
.scroll(window.positionAt(window.size() - 1));
assertThat(scroll.hasNext()).isFalse(); assertThat(window.hasNext()).isFalse();
assertThat(scroll.isLast()).isTrue(); assertThat(window.isLast()).isTrue();
assertThat(scroll).hasSize(1); assertThat(window).hasSize(1);
assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); assertThat(window).containsOnly(assertionConverter.apply(one));
} }
static Stream<Arguments> positions() { static Stream<Arguments> positions() {
return Stream.of(args(KeysetScrollPosition.initial(), Person.class, Function.identity()), // return Stream.of(args(KeysetScrollPosition.initial(), Person.class, Function.identity()), //
args(KeysetScrollPosition.initial(), Document.class, MongoTemplateScrollTests::toDocument), // args(KeysetScrollPosition.initial(), Document.class, MongoTemplateScrollTests::toDocument), //
args(OffsetScrollPosition.initial(), Person.class, Function.identity())); args(OffsetScrollPosition.initial(), Person.class, Function.identity()), //
args(OffsetScrollPosition.initial(), PersonDtoProjection.class,
MongoTemplateScrollTests::toPersonDtoProjection), //
args(OffsetScrollPosition.initial(), PersonInterfaceProjection.class,
MongoTemplateScrollTests::toPersonInterfaceProjection, MongoTemplateScrollTests::compareProxies));
}
static Stream<Arguments> renamedFieldProjectTargets() {
return Stream.of(Arguments.of(WithRenamedField.class, Function.identity()),
Arguments.of(Document.class, new Function<WithRenamedField, Document>() {
@Override
public Document apply(WithRenamedField withRenamedField) {
return new Document("_id", withRenamedField.getId()).append("_val", withRenamedField.getValue())
.append("_class", WithRenamedField.class.getName());
}
}));
}
static <T> org.assertj.core.api.IterableAssert<T> assertWindow(Window<T> window, @Nullable Comparator<T> comparator) {
return comparator != null ? assertThat(window).usingElementComparator(comparator) : assertThat(window);
} }
private static <T> Arguments args(ScrollPosition scrollPosition, Class<T> resultType, private static <T> Arguments args(ScrollPosition scrollPosition, Class<T> resultType,
Function<Person, T> assertionConverter) { Function<Person, T> assertionConverter) {
return Arguments.of(scrollPosition, resultType, assertionConverter); return args(scrollPosition, resultType, assertionConverter, null);
}
private static <T> Arguments args(ScrollPosition scrollPosition, Class<T> resultType,
Function<Person, T> assertionConverter, @Nullable Comparator<T> comparator) {
return Arguments.of(scrollPosition, resultType, assertionConverter, comparator);
} }
static Document toDocument(Person person) { static Document toDocument(Person person) {
return new Document("_class", person.getClass().getName()).append("_id", person.getId()).append("active", true) return new Document("_class", person.getClass().getName()).append("_id", person.getId()).append("active", true)
.append("firstName", person.getFirstName()).append("age", person.getAge()); .append("firstName", person.getFirstName()).append("age", person.getAge());
} }
static PersonDtoProjection toPersonDtoProjection(Person person) {
PersonDtoProjection dto = new PersonDtoProjection();
dto.firstName = person.getFirstName();
dto.age = person.getAge();
return dto;
}
static PersonInterfaceProjection toPersonInterfaceProjection(Person person) {
return new PersonInterfaceProjectionImpl(person);
}
@Data
static class PersonDtoProjection {
String firstName;
int age;
}
interface PersonInterfaceProjection {
String getFirstName();
int getAge();
}
static class PersonInterfaceProjectionImpl implements PersonInterfaceProjection {
final Person delegate;
public PersonInterfaceProjectionImpl(Person delegate) {
this.delegate = delegate;
}
@Override
public String getFirstName() {
return delegate.getFirstName();
}
@Override
public int getAge() {
return delegate.getAge();
}
@Override
public boolean equals(Object o) {
if (o instanceof Proxy) {
return true;
}
return false;
}
@Override
public int hashCode() {
return ObjectUtils.nullSafeHashCode(delegate);
}
}
@Data
@AllArgsConstructor
@NoArgsConstructor
static class WithRenamedField {
String id;
@Field("_val") String value;
WithRenamedField nested;
}
@NoArgsConstructor @NoArgsConstructor
@Data @Data
class WithNestedDocument { class WithNestedDocument {
String id; String id;
String name; String name;
int age; int age;

95
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateScrollTests.java

@ -18,6 +18,10 @@ package org.springframework.data.mongodb.core;
import static org.springframework.data.mongodb.core.query.Criteria.*; import static org.springframework.data.mongodb.core.query.Criteria.*;
import static org.springframework.data.mongodb.test.util.Assertions.*; import static org.springframework.data.mongodb.test.util.Assertions.*;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.data.mongodb.core.mapping.Field;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import java.time.Duration; import java.time.Duration;
@ -49,6 +53,7 @@ import com.mongodb.reactivestreams.client.MongoClient;
* Integration tests for {@link Window} queries. * Integration tests for {@link Window} queries.
* *
* @author Mark Paluch * @author Mark Paluch
* @author Christoph Strobl
*/ */
@ExtendWith(MongoClientExtension.class) @ExtendWith(MongoClientExtension.class)
class ReactiveMongoTemplateScrollTests { class ReactiveMongoTemplateScrollTests {
@ -78,6 +83,11 @@ class ReactiveMongoTemplateScrollTests {
.as(StepVerifier::create) // .as(StepVerifier::create) //
.expectNextCount(1) // .expectNextCount(1) //
.verifyComplete(); .verifyComplete();
template.remove(WithRenamedField.class).all() //
.as(StepVerifier::create) //
.expectNextCount(1) //
.verifyComplete();
} }
@ParameterizedTest // GH-4308 @ParameterizedTest // GH-4308
@ -100,29 +110,59 @@ class ReactiveMongoTemplateScrollTests {
Query q = new Query(where("firstName").regex("J.*")).with(Sort.by("firstName", "age")).limit(2); Query q = new Query(where("firstName").regex("J.*")).with(Sort.by("firstName", "age")).limit(2);
q.with(scrollPosition); q.with(scrollPosition);
Window<T> scroll = template.scroll(q, resultType, "person").block(Duration.ofSeconds(10)); Window<T> window = template.scroll(q, resultType, "person").block(Duration.ofSeconds(10));
assertThat(scroll.hasNext()).isTrue(); assertThat(window.hasNext()).isTrue();
assertThat(scroll.isLast()).isFalse(); assertThat(window.isLast()).isFalse();
assertThat(scroll).hasSize(2); assertThat(window).hasSize(2);
assertThat(scroll).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40)); assertThat(window).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40));
scroll = template.scroll(q.limit(3).with(scroll.positionAt(scroll.size() - 1)), resultType, "person") window = template.scroll(q.limit(3).with(window.positionAt(window.size() - 1)), resultType, "person")
.block(Duration.ofSeconds(10)); .block(Duration.ofSeconds(10));
assertThat(scroll.hasNext()).isTrue(); assertThat(window.hasNext()).isTrue();
assertThat(scroll.isLast()).isFalse(); assertThat(window.isLast()).isFalse();
assertThat(scroll).hasSize(3); assertThat(window).hasSize(3);
assertThat(scroll).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20)); assertThat(window).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20));
assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); assertThat(window).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2));
scroll = template.scroll(q.limit(1).with(scroll.positionAt(scroll.size() - 1)), resultType, "person") window = template.scroll(q.limit(1).with(window.positionAt(window.size() - 1)), resultType, "person")
.block(Duration.ofSeconds(10)); .block(Duration.ofSeconds(10));
assertThat(scroll.hasNext()).isFalse(); assertThat(window.hasNext()).isFalse();
assertThat(scroll.isLast()).isTrue(); assertThat(window.isLast()).isTrue();
assertThat(scroll).hasSize(1); assertThat(window).hasSize(1);
assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); assertThat(window).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2));
}
@ParameterizedTest // GH-4308
@MethodSource("renamedFieldProjectTargets")
<T> void scrollThroughResultsWithRenamedField(Class<T> resultType, Function<WithRenamedField, T> assertionConverter) {
WithRenamedField one = new WithRenamedField("id-1", "v1", null);
WithRenamedField two = new WithRenamedField("id-2", "v2", null);
WithRenamedField three = new WithRenamedField("id-3", "v3", null);
template.insertAll(Arrays.asList(one, two, three)).as(StepVerifier::create).expectNextCount(3).verifyComplete();
Query q = new Query(where("value").regex("v.*")).with(Sort.by(Sort.Direction.DESC, "value")).limit(2);
q.with(KeysetScrollPosition.initial());
Window<T> window = template.query(WithRenamedField.class).as(resultType).matching(q)
.scroll(KeysetScrollPosition.initial()).block(Duration.ofSeconds(10));
assertThat(window.hasNext()).isTrue();
assertThat(window.isLast()).isFalse();
assertThat(window).hasSize(2);
assertThat(window).containsOnly(assertionConverter.apply(three), assertionConverter.apply(two));
window = template.query(WithRenamedField.class).as(resultType).matching(q)
.scroll(window.positionAt(window.size() - 1)).block(Duration.ofSeconds(10));
assertThat(window.hasNext()).isFalse();
assertThat(window.isLast()).isTrue();
assertThat(window).hasSize(1);
assertThat(window).containsOnly(assertionConverter.apply(one));
} }
static Stream<Arguments> positions() { static Stream<Arguments> positions() {
@ -132,6 +172,17 @@ class ReactiveMongoTemplateScrollTests {
args(OffsetScrollPosition.initial(), Person.class, Function.identity())); args(OffsetScrollPosition.initial(), Person.class, Function.identity()));
} }
static Stream<Arguments> renamedFieldProjectTargets() {
return Stream.of(Arguments.of(WithRenamedField.class, Function.identity()),
Arguments.of(Document.class, new Function<WithRenamedField, Document>() {
@Override
public Document apply(WithRenamedField withRenamedField) {
return new Document("_id", withRenamedField.getId()).append("_val", withRenamedField.getValue())
.append("_class", WithRenamedField.class.getName());
}
}));
}
private static <T> Arguments args(ScrollPosition scrollPosition, Class<T> resultType, private static <T> Arguments args(ScrollPosition scrollPosition, Class<T> resultType,
Function<Person, T> assertionConverter) { Function<Person, T> assertionConverter) {
return Arguments.of(scrollPosition, resultType, assertionConverter); return Arguments.of(scrollPosition, resultType, assertionConverter);
@ -141,4 +192,16 @@ class ReactiveMongoTemplateScrollTests {
return new Document("_class", person.getClass().getName()).append("_id", person.getId()).append("active", true) return new Document("_class", person.getClass().getName()).append("_id", person.getId()).append("active", true)
.append("firstName", person.getFirstName()).append("age", person.getAge()); .append("firstName", person.getFirstName()).append("age", person.getAge());
} }
@Data
@AllArgsConstructor
@NoArgsConstructor
static class WithRenamedField {
String id;
@Field("_val") String value;
MongoTemplateScrollTests.WithRenamedField nested;
}
} }

Loading…
Cancel
Save