diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java index ea29f751d..c15eaaf85 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java @@ -62,7 +62,7 @@ class AggregationOperationRenderer { if (operation instanceof InheritsFieldsAggregationOperation || exposedFieldsOperation.inheritsFields()) { contextToUse = contextToUse.inheritAndExpose(fields); } else { - contextToUse = fields.exposesNoFields() ? DEFAULT_CONTEXT + contextToUse = fields.exposesNoFields() ? ConverterAwareNoOpContext.instance(rootContext) : contextToUse.expose(fields); } } @@ -72,6 +72,39 @@ class AggregationOperationRenderer { return operationDocuments; } + private static class ConverterAwareNoOpContext implements AggregationOperationContext { + + AggregationOperationContext ctx; + + static ConverterAwareNoOpContext instance(AggregationOperationContext ctx) { + + if(ctx instanceof ConverterAwareNoOpContext noOpContext) { + return noOpContext; + } + + return new ConverterAwareNoOpContext(ctx); + } + + ConverterAwareNoOpContext(AggregationOperationContext ctx) { + this.ctx = ctx; + } + + @Override + public Document getMappedObject(Document document, @Nullable Class type) { + return ctx.getMappedObject(document, null); + } + + @Override + public FieldReference getReference(Field field) { + return new DirectFieldReference(new ExposedField(field, true)); + } + + @Override + public FieldReference getReference(String name) { + return new DirectFieldReference(new ExposedField(new AggregationField(name), true)); + } + } + /** * Simple {@link AggregationOperationContext} that just returns {@link FieldReference}s as is. * diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java index a8b32f957..36d3e62a4 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java @@ -19,14 +19,28 @@ import static org.mockito.Mockito.*; import static org.springframework.data.domain.Sort.Direction.*; import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; +import java.time.ZonedDateTime; import java.util.List; +import java.util.Set; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.bson.Document; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.data.annotation.Id; +import org.springframework.data.convert.ConverterBuilder; +import org.springframework.data.convert.CustomConversions; +import org.springframework.data.convert.CustomConversions.StoreConversions; +import org.springframework.data.domain.Sort.Direction; import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.mapping.Field; +import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.test.util.MongoTestMappingContext; /** @@ -47,6 +61,79 @@ public class AggregationOperationRendererUnitTests { verify(stage2).toPipelineStages(eq(rootContext)); } + @Test + void contextShouldCarryOnRelaxedFieldMapping() { + + MongoTestMappingContext ctx = new MongoTestMappingContext(cfg -> { + cfg.initialEntitySet(TestRecord.class); + }); + + MappingMongoConverter mongoConverter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, ctx); + + Aggregation agg = Aggregation.newAggregation(Aggregation.unwind("layerOne.layerTwo"), + project().and("layerOne.layerTwo.layerThree").as("layerOne.layerThree"), + sort(DESC, "layerOne.layerThree.fieldA")); + + AggregationOperationRenderer.toDocument(agg.getPipeline().getOperations(), + new RelaxedTypeBasedAggregationOperationContext(TestRecord.class, ctx, new QueryMapper(mongoConverter))); + } + + @Test // GH-4722 + void appliesConversionToValuesUsedInAggregation() { + + MongoTestMappingContext ctx = new MongoTestMappingContext(cfg -> { + cfg.initialEntitySet(TestRecord.class); + }); + + MappingMongoConverter mongoConverter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, ctx); + mongoConverter.setCustomConversions(new CustomConversions(StoreConversions.NONE, + Set.copyOf(ConverterBuilder.writing(ZonedDateTime.class, String.class, ZonedDateTime::toString) + .andReading(it -> ZonedDateTime.parse(it)).getConverters()))); + mongoConverter.afterPropertiesSet(); + + var agg = Aggregation.newAggregation(Aggregation.sort(Direction.DESC, "version"), + Aggregation.group("entityId").first(Aggregation.ROOT).as("value"), Aggregation.replaceRoot("value"), + Aggregation.match(Criteria.where("createdDate").lt(ZonedDateTime.now())) // here is the problem + ); + + List document = AggregationOperationRenderer.toDocument(agg.getPipeline().getOperations(), + new RelaxedTypeBasedAggregationOperationContext(TestRecord.class, ctx, new QueryMapper(mongoConverter))); + Assertions.assertThat(document).last() + .extracting(it -> it.getEmbedded(List.of("$match", "createdDate", "$lt"), Object.class)) + .isInstanceOf(String.class); + } + + @ParameterizedTest // GH-4722 + @MethodSource("studentAggregationContexts") + void mapsOperationThatDoesNotExposeDedicatedFieldsCorrectly(AggregationOperationContext aggregationContext) { + + var agg = newAggregation(Student.class, Aggregation.unwind("grades"), Aggregation.replaceRoot("grades"), + Aggregation.project("grades")); + + List mappedPipeline = AggregationOperationRenderer.toDocument(agg.getPipeline().getOperations(), + aggregationContext); + + Assertions.assertThat(mappedPipeline).last().isEqualTo(Document.parse("{\"$project\": {\"grades\": 1}}")); + } + + private static Stream studentAggregationContexts() { + + MongoTestMappingContext ctx = new MongoTestMappingContext(cfg -> { + cfg.initialEntitySet(Student.class); + }); + + MappingMongoConverter mongoConverter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, ctx); + mongoConverter.afterPropertiesSet(); + + QueryMapper queryMapper = new QueryMapper(mongoConverter); + + return Stream.of( + Arguments + .of(new TypeBasedAggregationOperationContext(Student.class, ctx, queryMapper, FieldLookupPolicy.strict())), + Arguments.of( + new TypeBasedAggregationOperationContext(Student.class, ctx, queryMapper, FieldLookupPolicy.relaxed()))); + } + record TestRecord(@Id String field1, String field2, LayerOne layerOne) { record LayerOne(List layerTwo) { } @@ -54,25 +141,20 @@ public class AggregationOperationRendererUnitTests { record LayerTwo(LayerThree layerThree) { } - record LayerThree(int fieldA, int fieldB) - {} + record LayerThree(int fieldA, int fieldB) { + } } - @Test - void xxx() { + static class Student { - MongoTestMappingContext ctx = new MongoTestMappingContext(cfg -> { - cfg.initialEntitySet(TestRecord.class); - }); + @Field("mark") List grades; - MappingMongoConverter mongoConverter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, ctx); + } - Aggregation agg = Aggregation.newAggregation( - Aggregation.unwind("layerOne.layerTwo"), - project().and("layerOne.layerTwo.layerThree").as("layerOne.layerThree"), - sort(DESC, "layerOne.layerThree.fieldA") - ); + static class Grade { - AggregationOperationRenderer.toDocument(agg.getPipeline().getOperations(), new RelaxedTypeBasedAggregationOperationContext(TestRecord.class, ctx, new QueryMapper(mongoConverter))); + int points; + String grades; } + }