From f7b913535e06b9fea2acc7f0188261322eb9bbda Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Mon, 28 Nov 2022 11:13:25 +0100 Subject: [PATCH] Avoid multiple mapping iterations. A 2nd pass is no longer needed as the context already does all the work. Closes: #4043 Original pull request: #4240 --- spring-data-mongodb/pom.xml | 2 +- .../data/mongodb/core/AggregationUtil.java | 35 +++---------------- .../core/aggregation/AggregationTests.java | 33 ++++++++++++++++- 3 files changed, 38 insertions(+), 32 deletions(-) diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 83cbb58fc..7a74583a5 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -305,7 +305,7 @@ io.mockk - mockk + mockk-jvm ${mockk} test diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java index bfbd5d40a..968f3bda1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java @@ -15,8 +15,6 @@ */ package org.springframework.data.mongodb.core; -import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -24,23 +22,16 @@ import java.util.stream.Collectors; import org.bson.Document; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.aggregation.Aggregation; -import org.springframework.data.mongodb.core.aggregation.AggregationOperation; import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; -import org.springframework.data.mongodb.core.aggregation.AggregationOptions; import org.springframework.data.mongodb.core.aggregation.AggregationOptions.DomainTypeMapping; -import org.springframework.data.mongodb.core.aggregation.CountOperation; import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.convert.QueryMapper; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; -import org.springframework.data.mongodb.core.query.CriteriaDefinition; -import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.util.Lazy; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; -import org.springframework.util.ObjectUtils; /** * Utility methods to map {@link org.springframework.data.mongodb.core.aggregation.Aggregation} pipeline definitions and @@ -75,12 +66,11 @@ class AggregationUtil { if (!(aggregation instanceof TypedAggregation)) { - if(inputType == null) { + if (inputType == null) { return untypedMappingContext.get(); } - if (domainTypeMapping == DomainTypeMapping.STRICT - && !aggregation.getPipeline().containsUnionWith()) { + if (domainTypeMapping == DomainTypeMapping.STRICT && !aggregation.getPipeline().containsUnionWith()) { return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); } @@ -88,8 +78,7 @@ class AggregationUtil { } inputType = ((TypedAggregation) aggregation).getInputType(); - if (domainTypeMapping == DomainTypeMapping.STRICT - && !aggregation.getPipeline().containsUnionWith()) { + if (domainTypeMapping == DomainTypeMapping.STRICT && !aggregation.getPipeline().containsUnionWith()) { return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); } @@ -104,12 +93,7 @@ class AggregationUtil { * @return */ List createPipeline(Aggregation aggregation, AggregationOperationContext context) { - - if (ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) { - return aggregation.toPipeline(context); - } - - return mapAggregationPipeline(aggregation.toPipeline(context)); + return aggregation.toPipeline(context); } /** @@ -120,16 +104,7 @@ class AggregationUtil { * @return */ Document createCommand(String collection, Aggregation aggregation, AggregationOperationContext context) { - - Document command = aggregation.toDocument(collection, context); - - if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) { - return command; - } - - command.put("pipeline", mapAggregationPipeline(command.get("pipeline", List.class))); - - return command; + return aggregation.toDocument(collection, context); } private List mapAggregationPipeline(List pipeline) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index 2f58bec24..078eb65ba 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -44,7 +44,6 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; - import org.springframework.core.io.ClassPathResource; import org.springframework.data.annotation.Id; import org.springframework.data.domain.Sort; @@ -62,6 +61,7 @@ import org.springframework.data.mongodb.core.aggregation.VariableOperators.Let.E import org.springframework.data.mongodb.core.geo.GeoJsonPoint; import org.springframework.data.mongodb.core.index.GeoSpatialIndexType; import org.springframework.data.mongodb.core.index.GeospatialIndex; +import org.springframework.data.mongodb.core.mapping.MongoId; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; @@ -1943,6 +1943,25 @@ public class AggregationTests { assertThat(results.getMappedResults()).hasSize(1); } + @Test // GH-4043 + void considersMongoIdWithinTypedCollections() { + + UserRef userRef = new UserRef(); + userRef.id = "4ee921aca44fd11b3254e001"; + userRef.name = "u-1"; + + Widget widget = new Widget(); + widget.id = "w-1"; + widget.users = Collections.singletonList(userRef); + + mongoTemplate.save(widget); + + Criteria criteria = Criteria.where("users").elemMatch(Criteria.where("id").is("4ee921aca44fd11b3254e001")); + AggregationResults aggregate = mongoTemplate.aggregate(newAggregation(match(criteria)), Widget.class, + Widget.class); + assertThat(aggregate.getMappedResults()).contains(widget); + } + private void createUsersWithReferencedPersons() { mongoTemplate.dropCollection(User.class); @@ -2266,4 +2285,16 @@ public class AggregationTests { @Id String id; MyEnum enumValue; } + + @lombok.Data + static class Widget { + @Id String id; + List users; + } + + @lombok.Data + static class UserRef { + @MongoId String id; + String name; + } }