From d8c04f0ec98e6aae55fa8f3b79a287139d28b98f Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 17 Mar 2023 10:52:08 +0100 Subject: [PATCH] Use projecting read callback to allow interface projections. Along the lines fix entity operations proxy handling by reading the underlying map instead of inspecting the proxy interface. Also make sure to map potential raw fields back to the according property. See: #4308 Original Pull Request: #4317 --- .../data/mongodb/core/EntityOperations.java | 77 ++++-- .../data/mongodb/core/MongoTemplate.java | 5 +- .../mongodb/core/ReactiveMongoTemplate.java | 12 +- .../data/mongodb/core/ScrollUtils.java | 5 +- .../core/EntityOperationsUnitTests.java | 40 ++- .../core/MongoTemplateScrollTests.java | 239 ++++++++++++++---- .../ReactiveMongoTemplateScrollTests.java | 95 +++++-- 7 files changed, 382 insertions(+), 91 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java index cbf252f1d..9212e20ae 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java @@ -30,6 +30,8 @@ import org.springframework.data.mapping.IdentifierAccessor; import org.springframework.data.mapping.MappingException; import org.springframework.data.mapping.PersistentEntity; import org.springframework.data.mapping.PersistentPropertyAccessor; +import org.springframework.data.mapping.PersistentPropertyPath; +import org.springframework.data.mapping.PropertyPath; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mapping.model.ConvertingPropertyAccessor; import org.springframework.data.mongodb.core.CollectionOptions.TimeSeriesOptions; @@ -50,6 +52,7 @@ import org.springframework.data.mongodb.util.BsonUtils; import org.springframework.data.projection.EntityProjection; import org.springframework.data.projection.EntityProjectionIntrospector; import org.springframework.data.projection.ProjectionFactory; +import org.springframework.data.projection.TargetAware; import org.springframework.data.util.Optionals; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -117,12 +120,16 @@ class EntityOperations { Assert.notNull(entity, "Bean must not be null"); + if (entity instanceof TargetAware targetAware) { + return new SimpleMappedEntity((Map) targetAware.getTarget(), this); + } + if (entity instanceof String) { - return new UnmappedEntity(parse(entity.toString())); + return new UnmappedEntity(parse(entity.toString()), this); } if (entity instanceof Map) { - return new SimpleMappedEntity((Map) entity); + return new SimpleMappedEntity((Map) entity, this); } return MappedEntity.of(entity, context, this); @@ -142,11 +149,11 @@ class EntityOperations { Assert.notNull(conversionService, "ConversionService must not be null"); if (entity instanceof String) { - return new UnmappedEntity(parse(entity.toString())); + return new UnmappedEntity(parse(entity.toString()), this); } if (entity instanceof Map) { - return new SimpleMappedEntity((Map) entity); + return new SimpleMappedEntity((Map) entity, this); } return AdaptibleMappedEntity.of(entity, context, conversionService, this); @@ -287,7 +294,8 @@ class EntityOperations { */ public EntityProjection introspectProjection(Class resultType, Class entityType) { - if (!queryMapper.getMappingContext().hasPersistentEntityFor(entityType)) { + MongoPersistentEntity persistentEntity = queryMapper.getMappingContext().getPersistentEntity(entityType); + if (persistentEntity == null && !resultType.isInterface() || ClassUtils.isAssignable(Document.class, resultType)) { return (EntityProjection) EntityProjection.nonProjecting(resultType); } return introspector.introspect(resultType, entityType); @@ -369,6 +377,7 @@ class EntityOperations { * A representation of information about an entity. * * @author Oliver Gierke + * @author Christoph Strobl * @since 2.1 */ interface Entity { @@ -471,10 +480,10 @@ class EntityOperations { /** * @param sortObject * @return - * @since 3.1 + * @since 4.1 * @throws IllegalStateException if a sort key yields {@literal null}. */ - Map extractKeys(Document sortObject); + Map extractKeys(Document sortObject, Class sourceType); } @@ -523,9 +532,11 @@ class EntityOperations { private static class UnmappedEntity> implements AdaptibleEntity { private final T map; + private final EntityOperations entityOperations; - protected UnmappedEntity(T map) { + protected UnmappedEntity(T map, EntityOperations entityOperations) { this.map = map; + this.entityOperations = entityOperations; } @Override @@ -596,13 +607,19 @@ class EntityOperations { } @Override - public Map extractKeys(Document sortObject) { + public Map extractKeys(Document sortObject, Class sourceType) { Map keyset = new LinkedHashMap<>(); - keyset.put(ID_FIELD, getId()); + MongoPersistentEntity sourceEntity = entityOperations.context.getPersistentEntity(sourceType); + if (sourceEntity != null && sourceEntity.hasIdProperty()) { + keyset.put(sourceEntity.getRequiredIdProperty().getName(), getId()); + } else { + keyset.put(ID_FIELD, getId()); + } for (String key : sortObject.keySet()) { - Object value = BsonUtils.resolveValue(map, key); + + Object value = resolveValue(key, sourceEntity); if (value == null) { throw new IllegalStateException( @@ -614,12 +631,24 @@ class EntityOperations { return keyset; } + + @Nullable + private Object resolveValue(String key, @Nullable MongoPersistentEntity sourceEntity) { + + if (sourceEntity == null) { + return BsonUtils.resolveValue(map, key); + } + PropertyPath from = PropertyPath.from(key, sourceEntity.getTypeInformation()); + PersistentPropertyPath persistentPropertyPath = entityOperations.context + .getPersistentPropertyPath(from); + return BsonUtils.resolveValue(map, persistentPropertyPath.toDotPath(p -> p.getFieldName())); + } } private static class SimpleMappedEntity> extends UnmappedEntity { - protected SimpleMappedEntity(T map) { - super(map); + protected SimpleMappedEntity(T map, EntityOperations entityOperations) { + super(map, entityOperations); } @Override @@ -758,10 +787,15 @@ class EntityOperations { } @Override - public Map extractKeys(Document sortObject) { + public Map extractKeys(Document sortObject, Class sourceType) { Map keyset = new LinkedHashMap<>(); - keyset.put(entity.getRequiredIdProperty().getName(), getId()); + MongoPersistentEntity sourceEntity = entityOperations.context.getPersistentEntity(sourceType); + if (sourceEntity != null && sourceEntity.hasIdProperty()) { + keyset.put(sourceEntity.getRequiredIdProperty().getName(), getId()); + } else { + keyset.put(entity.getRequiredIdProperty().getName(), getId()); + } for (String key : sortObject.keySet()) { @@ -933,6 +967,14 @@ class EntityOperations { * @since 3.3 */ TimeSeriesOptions mapTimeSeriesOptions(TimeSeriesOptions options); + + /** + * @return the name of the id field. + * @since 4.1 + */ + default String getIdKeyName() { + return ID_FIELD; + } } /** @@ -1055,6 +1097,11 @@ class EntityOperations { MongoPersistentProperty persistentProperty = entity.getPersistentProperty(name); return persistentProperty != null ? persistentProperty.getFieldName() : name; } + + @Override + public String getIdKeyName() { + return entity.getIdProperty().getName(); + } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index ae2b26def..bc95f0cfe 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -870,7 +870,8 @@ public class MongoTemplate Assert.notNull(sourceClass, "Entity type must not be null"); Assert.notNull(targetClass, "Target type must not be null"); - ReadDocumentCallback callback = new ReadDocumentCallback<>(mongoConverter, targetClass, collectionName); + EntityProjection projection = operations.introspectProjection(targetClass, sourceClass); + ProjectingReadCallback callback = new ProjectingReadCallback<>(mongoConverter, projection, collectionName); int limit = query.isLimited() ? query.getLimit() + 1 : Integer.MAX_VALUE; if (query.hasKeyset()) { @@ -882,7 +883,7 @@ public class MongoTemplate keysetPaginationQuery.fields(), sourceClass, new QueryCursorPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass), callback); - return ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), result, operations); + return ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), result, sourceClass, operations); } List result = doFind(collectionName, createDelegate(query), query.getQueryObject(), query.getFieldsObject(), diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java index 22747d382..21869b6d6 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java @@ -849,6 +849,8 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati Assert.notNull(sourceClass, "Entity type must not be null"); Assert.notNull(targetClass, "Target type must not be null"); + EntityProjection projection = operations.introspectProjection(targetClass, sourceClass); + ProjectingReadCallback callback = new ProjectingReadCallback<>(mongoConverter, projection, collectionName); int limit = query.isLimited() ? query.getLimit() + 1 : Integer.MAX_VALUE; if (query.hasKeyset()) { @@ -857,15 +859,15 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati operations.getIdPropertyName(sourceClass)); Mono> result = doFind(collectionName, ReactiveCollectionPreparerDelegate.of(query), - keysetPaginationQuery.query(), keysetPaginationQuery.fields(), targetClass, - new QueryFindPublisherPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass)).collectList(); + keysetPaginationQuery.query(), keysetPaginationQuery.fields(), sourceClass, + new QueryFindPublisherPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass), callback).collectList(); - return result.map(it -> ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), it, operations)); + return result.map(it -> ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), it, sourceClass, operations)); } Mono> result = doFind(collectionName, ReactiveCollectionPreparerDelegate.of(query), query.getQueryObject(), - query.getFieldsObject(), targetClass, - new QueryFindPublisherPreparer(query, query.getSortObject(), limit, query.getSkip(), sourceClass)) + query.getFieldsObject(), sourceClass, + new QueryFindPublisherPreparer(query, query.getSortObject(), limit, query.getSkip(), sourceClass), callback) .collectList(); return result.map( diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ScrollUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ScrollUtils.java index 7c0ae5da9..112f95270 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ScrollUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ScrollUtils.java @@ -121,14 +121,15 @@ class ScrollUtils { return sortOrder == 1 ? "$gt" : "$lt"; } - static Window createWindow(Document sortObject, int limit, List result, EntityOperations operations) { + static Window createWindow(Document sortObject, int limit, List result, Class sourceType, + EntityOperations operations) { IntFunction positionFunction = value -> { T last = result.get(value); Entity entity = operations.forEntity(last); - Map keys = entity.extractKeys(sortObject); + Map keys = entity.extractKeys(sortObject, sourceType); return KeysetScrollPosition.of(keys); }; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/EntityOperationsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/EntityOperationsUnitTests.java index 6107abaa3..3635cf4e3 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/EntityOperationsUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/EntityOperationsUnitTests.java @@ -33,11 +33,13 @@ import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; import org.springframework.data.mongodb.core.mapping.TimeSeries; import org.springframework.data.mongodb.test.util.MongoTestMappingContext; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; /** * Unit tests for {@link EntityOperations}. * * @author Mark Paluch + * @author Christoph Strobl */ class EntityOperationsUnitTests { @@ -70,7 +72,8 @@ class EntityOperationsUnitTests { WithNestedDocument object = new WithNestedDocument("foo"); - Map keys = operations.forEntity(object).extractKeys(new Document("id", 1)); + Map keys = operations.forEntity(object).extractKeys(new Document("id", 1), + WithNestedDocument.class); assertThat(keys).containsEntry("id", "foo"); } @@ -80,7 +83,7 @@ class EntityOperationsUnitTests { Document object = new Document("id", "foo"); - Map keys = operations.forEntity(object).extractKeys(new Document("id", 1)); + Map keys = operations.forEntity(object).extractKeys(new Document("id", 1), Document.class); assertThat(keys).containsEntry("id", "foo"); } @@ -90,7 +93,8 @@ class EntityOperationsUnitTests { WithNestedDocument object = new WithNestedDocument("foo", new WithNestedDocument("bar"), null); - Map keys = operations.forEntity(object).extractKeys(new Document("nested.id", 1)); + Map keys = operations.forEntity(object).extractKeys(new Document("nested.id", 1), + WithNestedDocument.class); assertThat(keys).containsEntry("nested.id", "bar"); } @@ -101,7 +105,8 @@ class EntityOperationsUnitTests { WithNestedDocument object = new WithNestedDocument("foo", new WithNestedDocument("bar"), new Document("john", "doe")); - Map keys = operations.forEntity(object).extractKeys(new Document("document.john", 1)); + Map keys = operations.forEntity(object).extractKeys(new Document("document.john", 1), + WithNestedDocument.class); assertThat(keys).containsEntry("document.john", "doe"); } @@ -111,11 +116,32 @@ class EntityOperationsUnitTests { Document object = new Document("document", new Document("john", "doe")); - Map keys = operations.forEntity(object).extractKeys(new Document("document.john", 1)); + Map keys = operations.forEntity(object).extractKeys(new Document("document.john", 1), + Document.class); assertThat(keys).containsEntry("document.john", "doe"); } + @Test // GH-4308 + void shouldExtractIdPropertyNameFromRawDocument() { + + Document object = new Document("_id", "id-1").append("value", "val"); + + Map keys = operations.forEntity(object).extractKeys(new Document("value", 1), DomainTypeWithIdProperty.class); + + assertThat(keys).containsEntry("id", "id-1"); + } + + @Test // GH-4308 + void shouldExtractValuesFromProxy() { + + ProjectionInterface source = new SpelAwareProxyProjectionFactory().createProjection(ProjectionInterface.class, new Document("_id", "id-1").append("value", "val")); + + Map keys = operations.forEntity(source).extractKeys(new Document("value", 1), DomainTypeWithIdProperty.class); + + assertThat(keys).isEqualTo(new Document("id", "id-1").append("value", "val")); + } + EntityOperations.AdaptibleEntity initAdaptibleEntity(T source) { return operations.forEntity(source, conversionService); } @@ -150,4 +176,8 @@ class EntityOperationsUnitTests { this.id = id; } } + + interface ProjectionInterface { + String getValue(); + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateScrollTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateScrollTests.java index aca3ea4bc..88c03e915 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateScrollTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateScrollTests.java @@ -18,10 +18,13 @@ package org.springframework.data.mongodb.core; import static org.assertj.core.api.Assertions.*; import static org.springframework.data.mongodb.core.query.Criteria.*; +import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; +import java.lang.reflect.Proxy; import java.util.Arrays; +import java.util.Comparator; import java.util.function.Function; import java.util.stream.Stream; @@ -44,10 +47,13 @@ import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; import org.springframework.data.mapping.context.PersistentEntities; import org.springframework.data.mongodb.core.MongoTemplateTests.PersonWithIdPropertyOfTypeUUIDListener; +import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.test.util.Client; import org.springframework.data.mongodb.test.util.MongoClientExtension; import org.springframework.data.mongodb.test.util.MongoTestTemplate; +import org.springframework.lang.Nullable; +import org.springframework.util.ObjectUtils; import com.mongodb.client.MongoClient; @@ -90,10 +96,22 @@ class MongoTemplateScrollTests { }); }); + private static int compareProxies(PersonInterfaceProjection actual, PersonInterfaceProjection expected) { + if (actual.getAge() != expected.getAge()) { + return -1; + } + if (!ObjectUtils.nullSafeEquals(actual.getFirstName(), expected.getFirstName())) { + return -1; + } + + return 0; + } + @BeforeEach void setUp() { template.remove(Person.class).all(); template.remove(WithNestedDocument.class).all(); + template.remove(WithRenamedField.class).all(); } @Test // GH-4308 @@ -112,19 +130,19 @@ class MongoTemplateScrollTests { .limit(2); q.with(KeysetScrollPosition.initial()); - Window scroll = template.scroll(q, WithNestedDocument.class); + Window window = template.scroll(q, WithNestedDocument.class); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(2); - assertThat(scroll).containsOnly(john20, john40); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(john20, john40); - scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)), WithNestedDocument.class); + window = template.scroll(q.with(window.positionAt(window.size() - 1)), WithNestedDocument.class); - assertThat(scroll.hasNext()).isFalse(); - assertThat(scroll.isLast()).isTrue(); - assertThat(scroll).hasSize(1); - assertThat(scroll).containsOnly(john41); + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertThat(window).containsOnly(john41); } @Test // GH-4308 @@ -162,35 +180,35 @@ class MongoTemplateScrollTests { .limit(2); q.with(KeysetScrollPosition.initial()); - Window scroll = template.scroll(q, WithNestedDocument.class); + Window window = template.scroll(q, WithNestedDocument.class); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(2); - assertThat(scroll).containsOnly(john20, john40); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(john20, john40); - scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)), WithNestedDocument.class); + window = template.scroll(q.with(window.positionAt(window.size() - 1)), WithNestedDocument.class); - assertThat(scroll.hasNext()).isFalse(); - assertThat(scroll.isLast()).isTrue(); - assertThat(scroll).hasSize(1); - assertThat(scroll).containsOnly(john41); + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertThat(window).containsOnly(john41); - KeysetScrollPosition scrollPosition = (KeysetScrollPosition) scroll.positionAt(0); + KeysetScrollPosition scrollPosition = (KeysetScrollPosition) window.positionAt(0); KeysetScrollPosition reversePosition = KeysetScrollPosition.of(scrollPosition.getKeys(), Direction.Backward); - scroll = template.scroll(q.with(reversePosition), WithNestedDocument.class); + window = template.scroll(q.with(reversePosition), WithNestedDocument.class); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(2); - assertThat(scroll).containsOnly(john20, john40); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(john20, john40); } @ParameterizedTest // GH-4308 @MethodSource("positions") public void shouldApplyCursoringCorrectly(ScrollPosition scrollPosition, Class resultType, - Function assertionConverter) { + Function assertionConverter, @Nullable Comparator comparator) { Person john20 = new Person("John", 20); Person john40_1 = new Person("John", 40); @@ -201,53 +219,182 @@ class MongoTemplateScrollTests { template.insertAll(Arrays.asList(john20, john40_1, john40_2, jane_20, jane_40, jane_42)); Query q = new Query(where("firstName").regex("J.*")).with(Sort.by("firstName", "age")).limit(2); - q.with(scrollPosition); - Window scroll = template.scroll(q, resultType, "person"); + Window window = template.query(Person.class).inCollection("person").as(resultType).matching(q) + .scroll(scrollPosition); + + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertWindow(window, comparator).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40)); + + window = template.query(Person.class).inCollection("person").as(resultType).matching(q.limit(3)) + .scroll(window.positionAt(window.size() - 1)); + + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(3); + assertWindow(window, comparator).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20)); + assertWindow(window, comparator).containsAnyOf(assertionConverter.apply(john40_1), + assertionConverter.apply(john40_2)); + + window = template.query(Person.class).inCollection("person").as(resultType).matching(q.limit(1)) + .scroll(window.positionAt(window.size() - 1)); + + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertWindow(window, comparator).containsAnyOf(assertionConverter.apply(john40_1), + assertionConverter.apply(john40_2)); + } + + @ParameterizedTest // GH-4308 + @MethodSource("renamedFieldProjectTargets") + void scrollThroughResultsWithRenamedField(Class resultType, Function assertionConverter) { + + WithRenamedField one = new WithRenamedField("id-1", "v1", null); + WithRenamedField two = new WithRenamedField("id-2", "v2", null); + WithRenamedField three = new WithRenamedField("id-3", "v3", null); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(2); - assertThat(scroll).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40)); + template.insertAll(Arrays.asList(one, two, three)); - scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)).limit(3), resultType, "person"); + Query q = new Query(where("value").regex("v.*")).with(Sort.by(Sort.Direction.DESC, "value")).limit(2); + q.with(KeysetScrollPosition.initial()); + + Window window = template.query(WithRenamedField.class).as(resultType).matching(q) + .scroll(KeysetScrollPosition.initial()); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(3); - assertThat(scroll).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20)); - assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(assertionConverter.apply(three), assertionConverter.apply(two)); - scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)).limit(1), resultType, "person"); + window = template.query(WithRenamedField.class).as(resultType).matching(q) + .scroll(window.positionAt(window.size() - 1)); - assertThat(scroll.hasNext()).isFalse(); - assertThat(scroll.isLast()).isTrue(); - assertThat(scroll).hasSize(1); - assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertThat(window).containsOnly(assertionConverter.apply(one)); } static Stream positions() { return Stream.of(args(KeysetScrollPosition.initial(), Person.class, Function.identity()), // args(KeysetScrollPosition.initial(), Document.class, MongoTemplateScrollTests::toDocument), // - args(OffsetScrollPosition.initial(), Person.class, Function.identity())); + args(OffsetScrollPosition.initial(), Person.class, Function.identity()), // + args(OffsetScrollPosition.initial(), PersonDtoProjection.class, + MongoTemplateScrollTests::toPersonDtoProjection), // + args(OffsetScrollPosition.initial(), PersonInterfaceProjection.class, + MongoTemplateScrollTests::toPersonInterfaceProjection, MongoTemplateScrollTests::compareProxies)); + } + + static Stream renamedFieldProjectTargets() { + return Stream.of(Arguments.of(WithRenamedField.class, Function.identity()), + Arguments.of(Document.class, new Function() { + @Override + public Document apply(WithRenamedField withRenamedField) { + return new Document("_id", withRenamedField.getId()).append("_val", withRenamedField.getValue()) + .append("_class", WithRenamedField.class.getName()); + } + })); + } + + static org.assertj.core.api.IterableAssert assertWindow(Window window, @Nullable Comparator comparator) { + return comparator != null ? assertThat(window).usingElementComparator(comparator) : assertThat(window); } private static Arguments args(ScrollPosition scrollPosition, Class resultType, Function assertionConverter) { - return Arguments.of(scrollPosition, resultType, assertionConverter); + return args(scrollPosition, resultType, assertionConverter, null); + } + + private static Arguments args(ScrollPosition scrollPosition, Class resultType, + Function assertionConverter, @Nullable Comparator comparator) { + return Arguments.of(scrollPosition, resultType, assertionConverter, comparator); } static Document toDocument(Person person) { + return new Document("_class", person.getClass().getName()).append("_id", person.getId()).append("active", true) .append("firstName", person.getFirstName()).append("age", person.getAge()); } + static PersonDtoProjection toPersonDtoProjection(Person person) { + + PersonDtoProjection dto = new PersonDtoProjection(); + dto.firstName = person.getFirstName(); + dto.age = person.getAge(); + return dto; + } + + static PersonInterfaceProjection toPersonInterfaceProjection(Person person) { + + return new PersonInterfaceProjectionImpl(person); + } + + @Data + static class PersonDtoProjection { + String firstName; + int age; + } + + interface PersonInterfaceProjection { + String getFirstName(); + + int getAge(); + } + + static class PersonInterfaceProjectionImpl implements PersonInterfaceProjection { + + final Person delegate; + + public PersonInterfaceProjectionImpl(Person delegate) { + this.delegate = delegate; + } + + @Override + public String getFirstName() { + return delegate.getFirstName(); + } + + @Override + public int getAge() { + return delegate.getAge(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof Proxy) { + return true; + } + return false; + } + + @Override + public int hashCode() { + return ObjectUtils.nullSafeHashCode(delegate); + } + } + + @Data + @AllArgsConstructor + @NoArgsConstructor + static class WithRenamedField { + + String id; + + @Field("_val") String value; + + WithRenamedField nested; + } + @NoArgsConstructor @Data class WithNestedDocument { String id; + String name; int age; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateScrollTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateScrollTests.java index d42d8d99f..35f67782f 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateScrollTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateScrollTests.java @@ -18,6 +18,10 @@ package org.springframework.data.mongodb.core; import static org.springframework.data.mongodb.core.query.Criteria.*; import static org.springframework.data.mongodb.test.util.Assertions.*; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.springframework.data.mongodb.core.mapping.Field; import reactor.test.StepVerifier; import java.time.Duration; @@ -49,6 +53,7 @@ import com.mongodb.reactivestreams.client.MongoClient; * Integration tests for {@link Window} queries. * * @author Mark Paluch + * @author Christoph Strobl */ @ExtendWith(MongoClientExtension.class) class ReactiveMongoTemplateScrollTests { @@ -78,6 +83,11 @@ class ReactiveMongoTemplateScrollTests { .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); + + template.remove(WithRenamedField.class).all() // + .as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); } @ParameterizedTest // GH-4308 @@ -100,29 +110,59 @@ class ReactiveMongoTemplateScrollTests { Query q = new Query(where("firstName").regex("J.*")).with(Sort.by("firstName", "age")).limit(2); q.with(scrollPosition); - Window scroll = template.scroll(q, resultType, "person").block(Duration.ofSeconds(10)); + Window window = template.scroll(q, resultType, "person").block(Duration.ofSeconds(10)); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(2); - assertThat(scroll).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40)); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40)); - scroll = template.scroll(q.limit(3).with(scroll.positionAt(scroll.size() - 1)), resultType, "person") + window = template.scroll(q.limit(3).with(window.positionAt(window.size() - 1)), resultType, "person") .block(Duration.ofSeconds(10)); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(3); - assertThat(scroll).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20)); - assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(3); + assertThat(window).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20)); + assertThat(window).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); - scroll = template.scroll(q.limit(1).with(scroll.positionAt(scroll.size() - 1)), resultType, "person") + window = template.scroll(q.limit(1).with(window.positionAt(window.size() - 1)), resultType, "person") .block(Duration.ofSeconds(10)); - assertThat(scroll.hasNext()).isFalse(); - assertThat(scroll.isLast()).isTrue(); - assertThat(scroll).hasSize(1); - assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertThat(window).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); + } + + @ParameterizedTest // GH-4308 + @MethodSource("renamedFieldProjectTargets") + void scrollThroughResultsWithRenamedField(Class resultType, Function assertionConverter) { + + WithRenamedField one = new WithRenamedField("id-1", "v1", null); + WithRenamedField two = new WithRenamedField("id-2", "v2", null); + WithRenamedField three = new WithRenamedField("id-3", "v3", null); + + template.insertAll(Arrays.asList(one, two, three)).as(StepVerifier::create).expectNextCount(3).verifyComplete(); + + Query q = new Query(where("value").regex("v.*")).with(Sort.by(Sort.Direction.DESC, "value")).limit(2); + q.with(KeysetScrollPosition.initial()); + + Window window = template.query(WithRenamedField.class).as(resultType).matching(q) + .scroll(KeysetScrollPosition.initial()).block(Duration.ofSeconds(10)); + + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(assertionConverter.apply(three), assertionConverter.apply(two)); + + window = template.query(WithRenamedField.class).as(resultType).matching(q) + .scroll(window.positionAt(window.size() - 1)).block(Duration.ofSeconds(10)); + + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertThat(window).containsOnly(assertionConverter.apply(one)); } static Stream positions() { @@ -132,6 +172,17 @@ class ReactiveMongoTemplateScrollTests { args(OffsetScrollPosition.initial(), Person.class, Function.identity())); } + static Stream renamedFieldProjectTargets() { + return Stream.of(Arguments.of(WithRenamedField.class, Function.identity()), + Arguments.of(Document.class, new Function() { + @Override + public Document apply(WithRenamedField withRenamedField) { + return new Document("_id", withRenamedField.getId()).append("_val", withRenamedField.getValue()) + .append("_class", WithRenamedField.class.getName()); + } + })); + } + private static Arguments args(ScrollPosition scrollPosition, Class resultType, Function assertionConverter) { return Arguments.of(scrollPosition, resultType, assertionConverter); @@ -141,4 +192,16 @@ class ReactiveMongoTemplateScrollTests { return new Document("_class", person.getClass().getName()).append("_id", person.getId()).append("active", true) .append("firstName", person.getFirstName()).append("age", person.getAge()); } + + @Data + @AllArgsConstructor + @NoArgsConstructor + static class WithRenamedField { + + String id; + + @Field("_val") String value; + + MongoTemplateScrollTests.WithRenamedField nested; + } }