From adea4ba0a95dccfa3f20ade98b85fbc8b650f59a Mon Sep 17 00:00:00 2001 From: Julia Lee <5765049+sxhinzvc@users.noreply.github.com> Date: Wed, 26 Jul 2023 08:47:41 -0400 Subject: [PATCH] Fix mapping custom field names in downstream stages in TypedAggregation pipelines. Use the root AggregationOperationContext in nested ExposedFieldsAggregationOperationContext to properly apply mapping for domain properties that use @Field. Closes #4443 Original Pull Request: #4459 --- ...osedFieldsAggregationOperationContext.java | 5 ++++ .../core/aggregation/AggregationTests.java | 23 +++++++++++++++++-- .../aggregation/AggregationUnitTests.java | 21 ++++++++++++++++- 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java index 568a1a031..7387625fc 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java @@ -60,6 +60,11 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo return rootContext.getMappedObject(document, type); } + @Override + public Document getMappedObject(Document document) { + return rootContext.getMappedObject(document); + } + @Override public FieldReference getReference(Field field) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index a8cc529b0..01e42e64d 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -91,6 +91,7 @@ import com.mongodb.client.MongoCollection; * @author Sergey Shcherbakov * @author Minsu Kim * @author Sangyong Choi + * @author Julia Lee */ @ExtendWith(MongoTemplateExtension.class) public class AggregationTests { @@ -119,7 +120,7 @@ public class AggregationTests { mongoTemplate.flush(Product.class, UserWithLikes.class, DATAMONGO753.class, Data.class, DATAMONGO788.class, User.class, Person.class, Reservation.class, Venue.class, MeterData.class, LineItem.class, InventoryItem.class, - Sales.class, Sales2.class, Employee.class, Art.class, Venue.class); + Sales.class, Sales2.class, Employee.class, Art.class, Venue.class, Item.class); mongoTemplate.dropCollection(INPUT_COLLECTION); mongoTemplate.dropCollection("personQueryTemp"); @@ -1992,6 +1993,23 @@ public class AggregationTests { assertThat(aggregate.getMappedResults()).contains(widget); } + @Test // GH-4443 + void shouldHonorFieldAliasesForFieldReferencesUsingFieldExposingOperation() { + + Item item1 = Item.builder().itemId("1").tags(Arrays.asList("a", "b")).build(); + Item item2 = Item.builder().itemId("1").tags(Arrays.asList("a", "c")).build(); + mongoTemplate.insert(Arrays.asList(item1, item2), Item.class); + + TypedAggregation aggregation = newAggregation(Item.class, + match(where("itemId").is("1")), + unwind("tags"), + match(where("itemId").is("1").and("tags").is("c"))); + AggregationResults results = mongoTemplate.aggregate(aggregation, Document.class); + List mappedResults = results.getMappedResults(); + assertThat(mappedResults).hasSize(1); + assertThat(mappedResults.get(0)).containsEntry("item_id", "1"); + } + private void createUsersWithReferencedPersons() { mongoTemplate.dropCollection(User.class); @@ -2244,7 +2262,7 @@ public class AggregationTests { List items; } - // DATAMONGO-1491 + // DATAMONGO-1491, GH-4443 @lombok.Data @Builder static class Item { @@ -2253,6 +2271,7 @@ public class AggregationTests { String itemId; Integer quantity; Long price; + List tags = new ArrayList<>(); } // DATAMONGO-1538 diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java index 69721a45c..154a17f2f 100755 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java @@ -49,6 +49,7 @@ import com.mongodb.client.model.Projections; * @author Thomas Darimont * @author Christoph Strobl * @author Mark Paluch + * @author Julia Lee */ public class AggregationUnitTests { @@ -612,7 +613,7 @@ public class AggregationUnitTests { WithRetypedIdField.class, mappingContext, new QueryMapper(new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, mappingContext))); Document document = project(WithRetypedIdField.class).toDocument(context); - assertThat(document).isEqualTo(new Document("$project", new Document("_id", 1).append("renamed-field", 1))); + assertThat(document).isEqualTo(new Document("$project", new Document("_id", 1).append("renamed-field", 1).append("entries", 1))); } @Test // GH-4038 @@ -653,6 +654,22 @@ public class AggregationUnitTests { assertThat(documents.get(2)).isEqualTo("{ $sort : { 'serial_number' : -1, 'label_name' : -1 } }"); } + @Test // GH-4443 + void fieldsExposingContextShouldUseCustomFieldNameFromRelaxedRootContext() { + + MongoMappingContext mappingContext = new MongoMappingContext(); + RelaxedTypeBasedAggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext( + WithRetypedIdField.class, mappingContext, + new QueryMapper(new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, mappingContext))); + + TypedAggregation agg = newAggregation(WithRetypedIdField.class, + unwind("entries"), match(where("foo").is("value 2"))); + List pipeline = agg.toPipeline(context); + + Document fields = getAsDocument(pipeline.get(1), "$match"); + assertThat(fields.get("renamed-field")).isEqualTo("value 2"); + } + private Document extractPipelineElement(Document agg, int index, String operation) { List pipeline = (List) agg.get("pipeline"); @@ -672,5 +689,7 @@ public class AggregationUnitTests { @org.springframework.data.mongodb.core.mapping.Field("renamed-field") private String foo; + private List entries = new ArrayList<>(); + } }