From 8bcab93588b9930b9d13747e7fac1f961908aac7 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Mon, 28 Nov 2022 11:13:25 +0100 Subject: [PATCH] Avoid multiple mapping iterations. A 2nd pass is no longer needed as the context already does all the work. Closes: #4043 Original pull request: #4240 --- .../data/mongodb/core/AggregationUtil.java | 19 ++-------- .../core/aggregation/AggregationTests.java | 35 ++++++++++++++++++- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java index 63f27dca0..97635d5d9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java @@ -32,7 +32,6 @@ import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; import org.springframework.data.util.Lazy; import org.springframework.lang.Nullable; -import org.springframework.util.ObjectUtils; /** * Utility methods to map {@link org.springframework.data.mongodb.core.aggregation.Aggregation} pipeline definitions and @@ -96,12 +95,7 @@ class AggregationUtil { * @return */ List createPipeline(Aggregation aggregation, AggregationOperationContext context) { - - if (ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) { - return aggregation.toPipeline(context); - } - - return mapAggregationPipeline(aggregation.toPipeline(context)); + return aggregation.toPipeline(context); } /** @@ -112,16 +106,7 @@ class AggregationUtil { * @return */ Document createCommand(String collection, Aggregation aggregation, AggregationOperationContext context) { - - Document command = aggregation.toDocument(collection, context); - - if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) { - return command; - } - - command.put("pipeline", mapAggregationPipeline(command.get("pipeline", List.class))); - - return command; + return aggregation.toDocument(collection, context); } private List mapAggregationPipeline(List pipeline) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index 6ada403c4..c98b0c0b5 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -43,11 +43,11 @@ import java.util.stream.Stream; import org.assertj.core.data.Offset; import org.bson.Document; +import org.bson.types.ObjectId; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; - import org.springframework.core.io.ClassPathResource; import org.springframework.data.annotation.Id; import org.springframework.data.domain.Sort; @@ -65,6 +65,7 @@ import org.springframework.data.mongodb.core.aggregation.VariableOperators.Let.E import org.springframework.data.mongodb.core.geo.GeoJsonPoint; import org.springframework.data.mongodb.core.index.GeoSpatialIndexType; import org.springframework.data.mongodb.core.index.GeospatialIndex; +import org.springframework.data.mongodb.core.mapping.MongoId; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; @@ -1933,6 +1934,24 @@ public class AggregationTests { assertThat(results.getMappedResults()).hasSize(1); } + @Test // GH-4043 + void considersMongoIdWithinTypedCollections() { + + UserRef userRef = new UserRef(); + userRef.id = "4ee921aca44fd11b3254e001"; + userRef.name = "u-1"; + + Widget widget = new Widget(); + widget.id = "w-1"; + widget.users = List.of(userRef); + + mongoTemplate.save(widget); + + Criteria criteria = Criteria.where("users").elemMatch(Criteria.where("id").is("4ee921aca44fd11b3254e001")); + AggregationResults aggregate = mongoTemplate.aggregate(newAggregation(match(criteria)), Widget.class, Widget.class); + assertThat(aggregate.getMappedResults()).contains(widget); + } + private void createUsersWithReferencedPersons() { mongoTemplate.dropCollection(User.class); @@ -2250,4 +2269,18 @@ public class AggregationTests { @Id String id; MyEnum enumValue; } + + @lombok.Data + static class Widget { + @Id + String id; + List users; + } + + @lombok.Data + static class UserRef { + @MongoId + String id; + String name; + } }