diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java index cbf252f1d..9212e20ae 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java @@ -30,6 +30,8 @@ import org.springframework.data.mapping.IdentifierAccessor; import org.springframework.data.mapping.MappingException; import org.springframework.data.mapping.PersistentEntity; import org.springframework.data.mapping.PersistentPropertyAccessor; +import org.springframework.data.mapping.PersistentPropertyPath; +import org.springframework.data.mapping.PropertyPath; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mapping.model.ConvertingPropertyAccessor; import org.springframework.data.mongodb.core.CollectionOptions.TimeSeriesOptions; @@ -50,6 +52,7 @@ import org.springframework.data.mongodb.util.BsonUtils; import org.springframework.data.projection.EntityProjection; import org.springframework.data.projection.EntityProjectionIntrospector; import org.springframework.data.projection.ProjectionFactory; +import org.springframework.data.projection.TargetAware; import org.springframework.data.util.Optionals; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -117,12 +120,16 @@ class EntityOperations { Assert.notNull(entity, "Bean must not be null"); + if (entity instanceof TargetAware targetAware) { + return new SimpleMappedEntity((Map) targetAware.getTarget(), this); + } + if (entity instanceof String) { - return new UnmappedEntity(parse(entity.toString())); + return new UnmappedEntity(parse(entity.toString()), this); } if (entity instanceof Map) { - return new SimpleMappedEntity((Map) entity); + return new SimpleMappedEntity((Map) entity, this); } return MappedEntity.of(entity, context, this); @@ -142,11 +149,11 @@ class EntityOperations { Assert.notNull(conversionService, "ConversionService must not be null"); if (entity instanceof String) { - return new UnmappedEntity(parse(entity.toString())); + return new UnmappedEntity(parse(entity.toString()), this); } if (entity instanceof Map) { - return new SimpleMappedEntity((Map) entity); + return new SimpleMappedEntity((Map) entity, this); } return AdaptibleMappedEntity.of(entity, context, conversionService, this); @@ -287,7 +294,8 @@ class EntityOperations { */ public EntityProjection introspectProjection(Class resultType, Class entityType) { - if (!queryMapper.getMappingContext().hasPersistentEntityFor(entityType)) { + MongoPersistentEntity persistentEntity = queryMapper.getMappingContext().getPersistentEntity(entityType); + if (persistentEntity == null && !resultType.isInterface() || ClassUtils.isAssignable(Document.class, resultType)) { return (EntityProjection) EntityProjection.nonProjecting(resultType); } return introspector.introspect(resultType, entityType); @@ -369,6 +377,7 @@ class EntityOperations { * A representation of information about an entity. * * @author Oliver Gierke + * @author Christoph Strobl * @since 2.1 */ interface Entity { @@ -471,10 +480,10 @@ class EntityOperations { /** * @param sortObject * @return - * @since 3.1 + * @since 4.1 * @throws IllegalStateException if a sort key yields {@literal null}. */ - Map extractKeys(Document sortObject); + Map extractKeys(Document sortObject, Class sourceType); } @@ -523,9 +532,11 @@ class EntityOperations { private static class UnmappedEntity> implements AdaptibleEntity { private final T map; + private final EntityOperations entityOperations; - protected UnmappedEntity(T map) { + protected UnmappedEntity(T map, EntityOperations entityOperations) { this.map = map; + this.entityOperations = entityOperations; } @Override @@ -596,13 +607,19 @@ class EntityOperations { } @Override - public Map extractKeys(Document sortObject) { + public Map extractKeys(Document sortObject, Class sourceType) { Map keyset = new LinkedHashMap<>(); - keyset.put(ID_FIELD, getId()); + MongoPersistentEntity sourceEntity = entityOperations.context.getPersistentEntity(sourceType); + if (sourceEntity != null && sourceEntity.hasIdProperty()) { + keyset.put(sourceEntity.getRequiredIdProperty().getName(), getId()); + } else { + keyset.put(ID_FIELD, getId()); + } for (String key : sortObject.keySet()) { - Object value = BsonUtils.resolveValue(map, key); + + Object value = resolveValue(key, sourceEntity); if (value == null) { throw new IllegalStateException( @@ -614,12 +631,24 @@ class EntityOperations { return keyset; } + + @Nullable + private Object resolveValue(String key, @Nullable MongoPersistentEntity sourceEntity) { + + if (sourceEntity == null) { + return BsonUtils.resolveValue(map, key); + } + PropertyPath from = PropertyPath.from(key, sourceEntity.getTypeInformation()); + PersistentPropertyPath persistentPropertyPath = entityOperations.context + .getPersistentPropertyPath(from); + return BsonUtils.resolveValue(map, persistentPropertyPath.toDotPath(p -> p.getFieldName())); + } } private static class SimpleMappedEntity> extends UnmappedEntity { - protected SimpleMappedEntity(T map) { - super(map); + protected SimpleMappedEntity(T map, EntityOperations entityOperations) { + super(map, entityOperations); } @Override @@ -758,10 +787,15 @@ class EntityOperations { } @Override - public Map extractKeys(Document sortObject) { + public Map extractKeys(Document sortObject, Class sourceType) { Map keyset = new LinkedHashMap<>(); - keyset.put(entity.getRequiredIdProperty().getName(), getId()); + MongoPersistentEntity sourceEntity = entityOperations.context.getPersistentEntity(sourceType); + if (sourceEntity != null && sourceEntity.hasIdProperty()) { + keyset.put(sourceEntity.getRequiredIdProperty().getName(), getId()); + } else { + keyset.put(entity.getRequiredIdProperty().getName(), getId()); + } for (String key : sortObject.keySet()) { @@ -933,6 +967,14 @@ class EntityOperations { * @since 3.3 */ TimeSeriesOptions mapTimeSeriesOptions(TimeSeriesOptions options); + + /** + * @return the name of the id field. + * @since 4.1 + */ + default String getIdKeyName() { + return ID_FIELD; + } } /** @@ -1055,6 +1097,11 @@ class EntityOperations { MongoPersistentProperty persistentProperty = entity.getPersistentProperty(name); return persistentProperty != null ? persistentProperty.getFieldName() : name; } + + @Override + public String getIdKeyName() { + return entity.getIdProperty().getName(); + } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index ae2b26def..bc95f0cfe 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -870,7 +870,8 @@ public class MongoTemplate Assert.notNull(sourceClass, "Entity type must not be null"); Assert.notNull(targetClass, "Target type must not be null"); - ReadDocumentCallback callback = new ReadDocumentCallback<>(mongoConverter, targetClass, collectionName); + EntityProjection projection = operations.introspectProjection(targetClass, sourceClass); + ProjectingReadCallback callback = new ProjectingReadCallback<>(mongoConverter, projection, collectionName); int limit = query.isLimited() ? query.getLimit() + 1 : Integer.MAX_VALUE; if (query.hasKeyset()) { @@ -882,7 +883,7 @@ public class MongoTemplate keysetPaginationQuery.fields(), sourceClass, new QueryCursorPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass), callback); - return ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), result, operations); + return ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), result, sourceClass, operations); } List result = doFind(collectionName, createDelegate(query), query.getQueryObject(), query.getFieldsObject(), diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java index 22747d382..21869b6d6 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java @@ -849,6 +849,8 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati Assert.notNull(sourceClass, "Entity type must not be null"); Assert.notNull(targetClass, "Target type must not be null"); + EntityProjection projection = operations.introspectProjection(targetClass, sourceClass); + ProjectingReadCallback callback = new ProjectingReadCallback<>(mongoConverter, projection, collectionName); int limit = query.isLimited() ? query.getLimit() + 1 : Integer.MAX_VALUE; if (query.hasKeyset()) { @@ -857,15 +859,15 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati operations.getIdPropertyName(sourceClass)); Mono> result = doFind(collectionName, ReactiveCollectionPreparerDelegate.of(query), - keysetPaginationQuery.query(), keysetPaginationQuery.fields(), targetClass, - new QueryFindPublisherPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass)).collectList(); + keysetPaginationQuery.query(), keysetPaginationQuery.fields(), sourceClass, + new QueryFindPublisherPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass), callback).collectList(); - return result.map(it -> ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), it, operations)); + return result.map(it -> ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), it, sourceClass, operations)); } Mono> result = doFind(collectionName, ReactiveCollectionPreparerDelegate.of(query), query.getQueryObject(), - query.getFieldsObject(), targetClass, - new QueryFindPublisherPreparer(query, query.getSortObject(), limit, query.getSkip(), sourceClass)) + query.getFieldsObject(), sourceClass, + new QueryFindPublisherPreparer(query, query.getSortObject(), limit, query.getSkip(), sourceClass), callback) .collectList(); return result.map( diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ScrollUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ScrollUtils.java index 7c0ae5da9..112f95270 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ScrollUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ScrollUtils.java @@ -121,14 +121,15 @@ class ScrollUtils { return sortOrder == 1 ? "$gt" : "$lt"; } - static Window createWindow(Document sortObject, int limit, List result, EntityOperations operations) { + static Window createWindow(Document sortObject, int limit, List result, Class sourceType, + EntityOperations operations) { IntFunction positionFunction = value -> { T last = result.get(value); Entity entity = operations.forEntity(last); - Map keys = entity.extractKeys(sortObject); + Map keys = entity.extractKeys(sortObject, sourceType); return KeysetScrollPosition.of(keys); }; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/EntityOperationsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/EntityOperationsUnitTests.java index 6107abaa3..3635cf4e3 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/EntityOperationsUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/EntityOperationsUnitTests.java @@ -33,11 +33,13 @@ import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; import org.springframework.data.mongodb.core.mapping.TimeSeries; import org.springframework.data.mongodb.test.util.MongoTestMappingContext; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; /** * Unit tests for {@link EntityOperations}. * * @author Mark Paluch + * @author Christoph Strobl */ class EntityOperationsUnitTests { @@ -70,7 +72,8 @@ class EntityOperationsUnitTests { WithNestedDocument object = new WithNestedDocument("foo"); - Map keys = operations.forEntity(object).extractKeys(new Document("id", 1)); + Map keys = operations.forEntity(object).extractKeys(new Document("id", 1), + WithNestedDocument.class); assertThat(keys).containsEntry("id", "foo"); } @@ -80,7 +83,7 @@ class EntityOperationsUnitTests { Document object = new Document("id", "foo"); - Map keys = operations.forEntity(object).extractKeys(new Document("id", 1)); + Map keys = operations.forEntity(object).extractKeys(new Document("id", 1), Document.class); assertThat(keys).containsEntry("id", "foo"); } @@ -90,7 +93,8 @@ class EntityOperationsUnitTests { WithNestedDocument object = new WithNestedDocument("foo", new WithNestedDocument("bar"), null); - Map keys = operations.forEntity(object).extractKeys(new Document("nested.id", 1)); + Map keys = operations.forEntity(object).extractKeys(new Document("nested.id", 1), + WithNestedDocument.class); assertThat(keys).containsEntry("nested.id", "bar"); } @@ -101,7 +105,8 @@ class EntityOperationsUnitTests { WithNestedDocument object = new WithNestedDocument("foo", new WithNestedDocument("bar"), new Document("john", "doe")); - Map keys = operations.forEntity(object).extractKeys(new Document("document.john", 1)); + Map keys = operations.forEntity(object).extractKeys(new Document("document.john", 1), + WithNestedDocument.class); assertThat(keys).containsEntry("document.john", "doe"); } @@ -111,11 +116,32 @@ class EntityOperationsUnitTests { Document object = new Document("document", new Document("john", "doe")); - Map keys = operations.forEntity(object).extractKeys(new Document("document.john", 1)); + Map keys = operations.forEntity(object).extractKeys(new Document("document.john", 1), + Document.class); assertThat(keys).containsEntry("document.john", "doe"); } + @Test // GH-4308 + void shouldExtractIdPropertyNameFromRawDocument() { + + Document object = new Document("_id", "id-1").append("value", "val"); + + Map keys = operations.forEntity(object).extractKeys(new Document("value", 1), DomainTypeWithIdProperty.class); + + assertThat(keys).containsEntry("id", "id-1"); + } + + @Test // GH-4308 + void shouldExtractValuesFromProxy() { + + ProjectionInterface source = new SpelAwareProxyProjectionFactory().createProjection(ProjectionInterface.class, new Document("_id", "id-1").append("value", "val")); + + Map keys = operations.forEntity(source).extractKeys(new Document("value", 1), DomainTypeWithIdProperty.class); + + assertThat(keys).isEqualTo(new Document("id", "id-1").append("value", "val")); + } + EntityOperations.AdaptibleEntity initAdaptibleEntity(T source) { return operations.forEntity(source, conversionService); } @@ -150,4 +176,8 @@ class EntityOperationsUnitTests { this.id = id; } } + + interface ProjectionInterface { + String getValue(); + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateScrollTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateScrollTests.java index aca3ea4bc..88c03e915 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateScrollTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateScrollTests.java @@ -18,10 +18,13 @@ package org.springframework.data.mongodb.core; import static org.assertj.core.api.Assertions.*; import static org.springframework.data.mongodb.core.query.Criteria.*; +import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; +import java.lang.reflect.Proxy; import java.util.Arrays; +import java.util.Comparator; import java.util.function.Function; import java.util.stream.Stream; @@ -44,10 +47,13 @@ import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; import org.springframework.data.mapping.context.PersistentEntities; import org.springframework.data.mongodb.core.MongoTemplateTests.PersonWithIdPropertyOfTypeUUIDListener; +import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.test.util.Client; import org.springframework.data.mongodb.test.util.MongoClientExtension; import org.springframework.data.mongodb.test.util.MongoTestTemplate; +import org.springframework.lang.Nullable; +import org.springframework.util.ObjectUtils; import com.mongodb.client.MongoClient; @@ -90,10 +96,22 @@ class MongoTemplateScrollTests { }); }); + private static int compareProxies(PersonInterfaceProjection actual, PersonInterfaceProjection expected) { + if (actual.getAge() != expected.getAge()) { + return -1; + } + if (!ObjectUtils.nullSafeEquals(actual.getFirstName(), expected.getFirstName())) { + return -1; + } + + return 0; + } + @BeforeEach void setUp() { template.remove(Person.class).all(); template.remove(WithNestedDocument.class).all(); + template.remove(WithRenamedField.class).all(); } @Test // GH-4308 @@ -112,19 +130,19 @@ class MongoTemplateScrollTests { .limit(2); q.with(KeysetScrollPosition.initial()); - Window scroll = template.scroll(q, WithNestedDocument.class); + Window window = template.scroll(q, WithNestedDocument.class); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(2); - assertThat(scroll).containsOnly(john20, john40); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(john20, john40); - scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)), WithNestedDocument.class); + window = template.scroll(q.with(window.positionAt(window.size() - 1)), WithNestedDocument.class); - assertThat(scroll.hasNext()).isFalse(); - assertThat(scroll.isLast()).isTrue(); - assertThat(scroll).hasSize(1); - assertThat(scroll).containsOnly(john41); + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertThat(window).containsOnly(john41); } @Test // GH-4308 @@ -162,35 +180,35 @@ class MongoTemplateScrollTests { .limit(2); q.with(KeysetScrollPosition.initial()); - Window scroll = template.scroll(q, WithNestedDocument.class); + Window window = template.scroll(q, WithNestedDocument.class); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(2); - assertThat(scroll).containsOnly(john20, john40); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(john20, john40); - scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)), WithNestedDocument.class); + window = template.scroll(q.with(window.positionAt(window.size() - 1)), WithNestedDocument.class); - assertThat(scroll.hasNext()).isFalse(); - assertThat(scroll.isLast()).isTrue(); - assertThat(scroll).hasSize(1); - assertThat(scroll).containsOnly(john41); + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertThat(window).containsOnly(john41); - KeysetScrollPosition scrollPosition = (KeysetScrollPosition) scroll.positionAt(0); + KeysetScrollPosition scrollPosition = (KeysetScrollPosition) window.positionAt(0); KeysetScrollPosition reversePosition = KeysetScrollPosition.of(scrollPosition.getKeys(), Direction.Backward); - scroll = template.scroll(q.with(reversePosition), WithNestedDocument.class); + window = template.scroll(q.with(reversePosition), WithNestedDocument.class); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(2); - assertThat(scroll).containsOnly(john20, john40); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(john20, john40); } @ParameterizedTest // GH-4308 @MethodSource("positions") public void shouldApplyCursoringCorrectly(ScrollPosition scrollPosition, Class resultType, - Function assertionConverter) { + Function assertionConverter, @Nullable Comparator comparator) { Person john20 = new Person("John", 20); Person john40_1 = new Person("John", 40); @@ -201,53 +219,182 @@ class MongoTemplateScrollTests { template.insertAll(Arrays.asList(john20, john40_1, john40_2, jane_20, jane_40, jane_42)); Query q = new Query(where("firstName").regex("J.*")).with(Sort.by("firstName", "age")).limit(2); - q.with(scrollPosition); - Window scroll = template.scroll(q, resultType, "person"); + Window window = template.query(Person.class).inCollection("person").as(resultType).matching(q) + .scroll(scrollPosition); + + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertWindow(window, comparator).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40)); + + window = template.query(Person.class).inCollection("person").as(resultType).matching(q.limit(3)) + .scroll(window.positionAt(window.size() - 1)); + + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(3); + assertWindow(window, comparator).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20)); + assertWindow(window, comparator).containsAnyOf(assertionConverter.apply(john40_1), + assertionConverter.apply(john40_2)); + + window = template.query(Person.class).inCollection("person").as(resultType).matching(q.limit(1)) + .scroll(window.positionAt(window.size() - 1)); + + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertWindow(window, comparator).containsAnyOf(assertionConverter.apply(john40_1), + assertionConverter.apply(john40_2)); + } + + @ParameterizedTest // GH-4308 + @MethodSource("renamedFieldProjectTargets") + void scrollThroughResultsWithRenamedField(Class resultType, Function assertionConverter) { + + WithRenamedField one = new WithRenamedField("id-1", "v1", null); + WithRenamedField two = new WithRenamedField("id-2", "v2", null); + WithRenamedField three = new WithRenamedField("id-3", "v3", null); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(2); - assertThat(scroll).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40)); + template.insertAll(Arrays.asList(one, two, three)); - scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)).limit(3), resultType, "person"); + Query q = new Query(where("value").regex("v.*")).with(Sort.by(Sort.Direction.DESC, "value")).limit(2); + q.with(KeysetScrollPosition.initial()); + + Window window = template.query(WithRenamedField.class).as(resultType).matching(q) + .scroll(KeysetScrollPosition.initial()); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(3); - assertThat(scroll).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20)); - assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(assertionConverter.apply(three), assertionConverter.apply(two)); - scroll = template.scroll(q.with(scroll.positionAt(scroll.size() - 1)).limit(1), resultType, "person"); + window = template.query(WithRenamedField.class).as(resultType).matching(q) + .scroll(window.positionAt(window.size() - 1)); - assertThat(scroll.hasNext()).isFalse(); - assertThat(scroll.isLast()).isTrue(); - assertThat(scroll).hasSize(1); - assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertThat(window).containsOnly(assertionConverter.apply(one)); } static Stream positions() { return Stream.of(args(KeysetScrollPosition.initial(), Person.class, Function.identity()), // args(KeysetScrollPosition.initial(), Document.class, MongoTemplateScrollTests::toDocument), // - args(OffsetScrollPosition.initial(), Person.class, Function.identity())); + args(OffsetScrollPosition.initial(), Person.class, Function.identity()), // + args(OffsetScrollPosition.initial(), PersonDtoProjection.class, + MongoTemplateScrollTests::toPersonDtoProjection), // + args(OffsetScrollPosition.initial(), PersonInterfaceProjection.class, + MongoTemplateScrollTests::toPersonInterfaceProjection, MongoTemplateScrollTests::compareProxies)); + } + + static Stream renamedFieldProjectTargets() { + return Stream.of(Arguments.of(WithRenamedField.class, Function.identity()), + Arguments.of(Document.class, new Function() { + @Override + public Document apply(WithRenamedField withRenamedField) { + return new Document("_id", withRenamedField.getId()).append("_val", withRenamedField.getValue()) + .append("_class", WithRenamedField.class.getName()); + } + })); + } + + static org.assertj.core.api.IterableAssert assertWindow(Window window, @Nullable Comparator comparator) { + return comparator != null ? assertThat(window).usingElementComparator(comparator) : assertThat(window); } private static Arguments args(ScrollPosition scrollPosition, Class resultType, Function assertionConverter) { - return Arguments.of(scrollPosition, resultType, assertionConverter); + return args(scrollPosition, resultType, assertionConverter, null); + } + + private static Arguments args(ScrollPosition scrollPosition, Class resultType, + Function assertionConverter, @Nullable Comparator comparator) { + return Arguments.of(scrollPosition, resultType, assertionConverter, comparator); } static Document toDocument(Person person) { + return new Document("_class", person.getClass().getName()).append("_id", person.getId()).append("active", true) .append("firstName", person.getFirstName()).append("age", person.getAge()); } + static PersonDtoProjection toPersonDtoProjection(Person person) { + + PersonDtoProjection dto = new PersonDtoProjection(); + dto.firstName = person.getFirstName(); + dto.age = person.getAge(); + return dto; + } + + static PersonInterfaceProjection toPersonInterfaceProjection(Person person) { + + return new PersonInterfaceProjectionImpl(person); + } + + @Data + static class PersonDtoProjection { + String firstName; + int age; + } + + interface PersonInterfaceProjection { + String getFirstName(); + + int getAge(); + } + + static class PersonInterfaceProjectionImpl implements PersonInterfaceProjection { + + final Person delegate; + + public PersonInterfaceProjectionImpl(Person delegate) { + this.delegate = delegate; + } + + @Override + public String getFirstName() { + return delegate.getFirstName(); + } + + @Override + public int getAge() { + return delegate.getAge(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof Proxy) { + return true; + } + return false; + } + + @Override + public int hashCode() { + return ObjectUtils.nullSafeHashCode(delegate); + } + } + + @Data + @AllArgsConstructor + @NoArgsConstructor + static class WithRenamedField { + + String id; + + @Field("_val") String value; + + WithRenamedField nested; + } + @NoArgsConstructor @Data class WithNestedDocument { String id; + String name; int age; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateScrollTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateScrollTests.java index d42d8d99f..35f67782f 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateScrollTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateScrollTests.java @@ -18,6 +18,10 @@ package org.springframework.data.mongodb.core; import static org.springframework.data.mongodb.core.query.Criteria.*; import static org.springframework.data.mongodb.test.util.Assertions.*; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.springframework.data.mongodb.core.mapping.Field; import reactor.test.StepVerifier; import java.time.Duration; @@ -49,6 +53,7 @@ import com.mongodb.reactivestreams.client.MongoClient; * Integration tests for {@link Window} queries. * * @author Mark Paluch + * @author Christoph Strobl */ @ExtendWith(MongoClientExtension.class) class ReactiveMongoTemplateScrollTests { @@ -78,6 +83,11 @@ class ReactiveMongoTemplateScrollTests { .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); + + template.remove(WithRenamedField.class).all() // + .as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); } @ParameterizedTest // GH-4308 @@ -100,29 +110,59 @@ class ReactiveMongoTemplateScrollTests { Query q = new Query(where("firstName").regex("J.*")).with(Sort.by("firstName", "age")).limit(2); q.with(scrollPosition); - Window scroll = template.scroll(q, resultType, "person").block(Duration.ofSeconds(10)); + Window window = template.scroll(q, resultType, "person").block(Duration.ofSeconds(10)); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(2); - assertThat(scroll).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40)); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(assertionConverter.apply(jane_20), assertionConverter.apply(jane_40)); - scroll = template.scroll(q.limit(3).with(scroll.positionAt(scroll.size() - 1)), resultType, "person") + window = template.scroll(q.limit(3).with(window.positionAt(window.size() - 1)), resultType, "person") .block(Duration.ofSeconds(10)); - assertThat(scroll.hasNext()).isTrue(); - assertThat(scroll.isLast()).isFalse(); - assertThat(scroll).hasSize(3); - assertThat(scroll).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20)); - assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(3); + assertThat(window).contains(assertionConverter.apply(jane_42), assertionConverter.apply(john20)); + assertThat(window).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); - scroll = template.scroll(q.limit(1).with(scroll.positionAt(scroll.size() - 1)), resultType, "person") + window = template.scroll(q.limit(1).with(window.positionAt(window.size() - 1)), resultType, "person") .block(Duration.ofSeconds(10)); - assertThat(scroll.hasNext()).isFalse(); - assertThat(scroll.isLast()).isTrue(); - assertThat(scroll).hasSize(1); - assertThat(scroll).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertThat(window).containsAnyOf(assertionConverter.apply(john40_1), assertionConverter.apply(john40_2)); + } + + @ParameterizedTest // GH-4308 + @MethodSource("renamedFieldProjectTargets") + void scrollThroughResultsWithRenamedField(Class resultType, Function assertionConverter) { + + WithRenamedField one = new WithRenamedField("id-1", "v1", null); + WithRenamedField two = new WithRenamedField("id-2", "v2", null); + WithRenamedField three = new WithRenamedField("id-3", "v3", null); + + template.insertAll(Arrays.asList(one, two, three)).as(StepVerifier::create).expectNextCount(3).verifyComplete(); + + Query q = new Query(where("value").regex("v.*")).with(Sort.by(Sort.Direction.DESC, "value")).limit(2); + q.with(KeysetScrollPosition.initial()); + + Window window = template.query(WithRenamedField.class).as(resultType).matching(q) + .scroll(KeysetScrollPosition.initial()).block(Duration.ofSeconds(10)); + + assertThat(window.hasNext()).isTrue(); + assertThat(window.isLast()).isFalse(); + assertThat(window).hasSize(2); + assertThat(window).containsOnly(assertionConverter.apply(three), assertionConverter.apply(two)); + + window = template.query(WithRenamedField.class).as(resultType).matching(q) + .scroll(window.positionAt(window.size() - 1)).block(Duration.ofSeconds(10)); + + assertThat(window.hasNext()).isFalse(); + assertThat(window.isLast()).isTrue(); + assertThat(window).hasSize(1); + assertThat(window).containsOnly(assertionConverter.apply(one)); } static Stream positions() { @@ -132,6 +172,17 @@ class ReactiveMongoTemplateScrollTests { args(OffsetScrollPosition.initial(), Person.class, Function.identity())); } + static Stream renamedFieldProjectTargets() { + return Stream.of(Arguments.of(WithRenamedField.class, Function.identity()), + Arguments.of(Document.class, new Function() { + @Override + public Document apply(WithRenamedField withRenamedField) { + return new Document("_id", withRenamedField.getId()).append("_val", withRenamedField.getValue()) + .append("_class", WithRenamedField.class.getName()); + } + })); + } + private static Arguments args(ScrollPosition scrollPosition, Class resultType, Function assertionConverter) { return Arguments.of(scrollPosition, resultType, assertionConverter); @@ -141,4 +192,16 @@ class ReactiveMongoTemplateScrollTests { return new Document("_class", person.getClass().getName()).append("_id", person.getId()).append("active", true) .append("firstName", person.getFirstName()).append("age", person.getAge()); } + + @Data + @AllArgsConstructor + @NoArgsConstructor + static class WithRenamedField { + + String id; + + @Field("_val") String value; + + MongoTemplateScrollTests.WithRenamedField nested; + } }