diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index cc09f54ec..cb9e70dd1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -28,6 +28,7 @@ import org.springframework.data.mongodb.core.aggregation.AddFieldsOperation.AddF import org.springframework.data.mongodb.core.aggregation.CountOperation.CountOperationBuilder; import org.springframework.data.mongodb.core.aggregation.FacetOperation.FacetOperationBuilder; import org.springframework.data.mongodb.core.aggregation.GraphLookupOperation.StartWithBuilder; +import org.springframework.data.mongodb.core.aggregation.LookupOperation.LookupOperationBuilder; import org.springframework.data.mongodb.core.aggregation.MergeOperation.MergeOperationBuilder; import org.springframework.data.mongodb.core.aggregation.ReplaceRootOperation.ReplaceRootDocumentOperationBuilder; import org.springframework.data.mongodb.core.aggregation.ReplaceRootOperation.ReplaceRootOperationBuilder; @@ -665,16 +666,21 @@ public class Aggregation { return new LookupOperation(from, localField, foreignField, as); } - public static LookupOperation lookup(String from, String localField, String foreignField, String as, List aggregationOperations) { - return lookup(field(from), field(localField), field(foreignField), field(as), null, new AggregationPipeline(aggregationOperations)); - } - - public static LookupOperation lookup(String from, String localField, String foreignField, String as, List letExpressionVars, List aggregationOperations) { - return lookup(field(from), field(localField), field(foreignField), field(as), new LookupOperation.Let(letExpressionVars), new AggregationPipeline(aggregationOperations)); - } - - public static LookupOperation lookup(Field from, Field localField, Field foreignField, Field as, LookupOperation.Let let, AggregationPipeline pipeline) { - return new LookupOperation(from, localField, foreignField, as, let, pipeline); + /** + * Entrypoint for creating {@link LookupOperation $lookup} using a fluent builder API. + *
+	 * Aggregation.lookup().from("restaurants")
+	 * 	.localField("restaurant_name")
+	 * 	.foreignField("name")
+	 * 	.let(newVariable("orders_drink").forField("drink"))
+	 * 	.pipeline(match(ctx -> new Document("$expr", new Document("$in", List.of("$$orders_drink", "$beverages")))))
+	 * 	.as("matches")
+	 * 
+ * @return new instance of {@link LookupOperationBuilder}. + * @since 4.1 + */ + public static LookupOperationBuilder lookup() { + return new LookupOperationBuilder(); } /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java index ff7999e43..44d0f1569 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/LookupOperation.java @@ -15,17 +15,19 @@ */ package org.springframework.data.mongodb.core.aggregation; -import java.util.List; +import java.util.function.Supplier; import org.bson.Document; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; import org.springframework.data.mongodb.core.aggregation.FieldsExposingAggregationOperation.InheritsFieldsAggregationOperation; +import org.springframework.data.mongodb.core.aggregation.VariableOperators.Let; +import org.springframework.data.mongodb.core.aggregation.VariableOperators.Let.ExpressionVariable; import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** - * Encapsulates the aggregation framework {@code $lookup}-operation. We recommend to use the static factory method - * {@link Aggregation#lookup(String, String, String, String)} instead of creating instances of this class directly. + * Encapsulates the aggregation framework {@code $lookup}-operation. We recommend to use the builder provided via + * {@link #newLookup()} instead of creating instances of this class directly. * * @author Alessio Fachechi * @author Christoph Strobl @@ -37,16 +39,22 @@ import org.springframework.util.Assert; */ public class LookupOperation implements FieldsExposingAggregationOperation, InheritsFieldsAggregationOperation { - private final Field from; + private final String from; + + @Nullable // private final Field localField; + + @Nullable // private final Field foreignField; - private final ExposedField as; - @Nullable + @Nullable // private final Let let; - @Nullable + + @Nullable // private final AggregationPipeline pipeline; + private final ExposedField as; + /** * Creates a new {@link LookupOperation} for the given {@link Field}s. * @@ -56,13 +64,47 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe * @param as must not be {@literal null}. */ public LookupOperation(Field from, Field localField, Field foreignField, Field as) { - this(from, localField, foreignField, as, null, null); + this(((Supplier) () -> { + + Assert.notNull(from, "From must not be null"); + return from.getTarget(); + }).get(), localField, foreignField, null, null, as); } - public LookupOperation(Field from, Field localField, Field foreignField, Field as, @Nullable Let let, @Nullable AggregationPipeline pipeline) { + /** + * Creates a new {@link LookupOperation} for the given combination of {@link Field}s and {@link AggregationPipeline + * pipeline}. + * + * @param from must not be {@literal null}. + * @param let must not be {@literal null}. + * @param as must not be {@literal null}. + * @since 4.1 + */ + public LookupOperation(String from, @Nullable Let let, AggregationPipeline pipeline, Field as) { + this(from, null, null, let, pipeline, as); + } + + /** + * Creates a new {@link LookupOperation} for the given combination of {@link Field}s and {@link AggregationPipeline + * pipeline}. + * + * @param from must not be {@literal null}. + * @param localField can be {@literal null} if {@literal pipeline} is present. + * @param foreignField can be {@literal null} if {@literal pipeline} is present. + * @param let can be {@literal null} if {@literal localField} and {@literal foreignField} are present. + * @param as must not be {@literal null}. + * @since 4.1 + */ + public LookupOperation(String from, @Nullable Field localField, @Nullable Field foreignField, @Nullable Let let, + @Nullable AggregationPipeline pipeline, Field as) { + Assert.notNull(from, "From must not be null"); - Assert.notNull(localField, "LocalField must not be null"); - Assert.notNull(foreignField, "ForeignField must not be null"); + if (pipeline == null) { + Assert.notNull(localField, "LocalField must not be null"); + Assert.notNull(foreignField, "ForeignField must not be null"); + } else if (localField == null && foreignField == null) { + Assert.notNull(pipeline, "Pipeline must not be null"); + } Assert.notNull(as, "As must not be null"); this.from = from; @@ -83,19 +125,22 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe Document lookupObject = new Document(); - lookupObject.append("from", from.getTarget()); - lookupObject.append("localField", localField.getTarget()); - lookupObject.append("foreignField", foreignField.getTarget()); - lookupObject.append("as", as.getTarget()); - + lookupObject.append("from", from); + if (localField != null) { + lookupObject.append("localField", localField.getTarget()); + } + if (foreignField != null) { + lookupObject.append("foreignField", foreignField.getTarget()); + } if (let != null) { - lookupObject.append("let", let.toDocument(context)); + lookupObject.append("let", let.toDocument(context).get("$let", Document.class).get("vars")); } - if (pipeline != null) { lookupObject.append("pipeline", pipeline.toDocuments(context)); } + lookupObject.append("as", as.getTarget()); + return new Document(getOperator(), lookupObject); } @@ -122,7 +167,7 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe LocalFieldBuilder from(String name); } - public static interface LocalFieldBuilder { + public static interface LocalFieldBuilder extends PipelineBuilder { /** * @param name the field from the documents input to the {@code $lookup} stage, must not be {@literal null} or @@ -141,7 +186,67 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe AsBuilder foreignField(String name); } - public static interface AsBuilder { + /** + * @since 4.1 + * @author Christoph Strobl + */ + public interface LetBuilder { + + /** + * Specifies {@link Let#getVariableNames() variables) that can be used in the + * {@link PipelineBuilder#pipeline(AggregationOperation...) pipeline stages}. + * + * @param let must not be {@literal null}. + * @return never {@literal null}. + * @see PipelineBuilder + */ + PipelineBuilder let(Let let); + + /** + * Specifies {@link Let#getVariableNames() variables) that can be used in the + * {@link PipelineBuilder#pipeline(AggregationOperation...) pipeline stages}. + * + * @param variables must not be {@literal null}. + * @return never {@literal null}. + * @see PipelineBuilder + */ + default PipelineBuilder let(ExpressionVariable... variables) { + return let(Let.just(variables)); + } + } + + /** + * @since 4.1 + * @author Christoph Strobl + */ + public interface PipelineBuilder extends LetBuilder { + + /** + * Specifies the {@link AggregationPipeline pipeline} that determines the resulting documents. + * + * @param pipeline must not be {@literal null}. + * @return never {@literal null}. + */ + AsBuilder pipeline(AggregationPipeline pipeline); + + /** + * Specifies the {@link AggregationPipeline#getOperations() stages} that determine the resulting documents. + * + * @param stages must not be {@literal null} can be empty. + * @return never {@literal null}. + */ + default AsBuilder pipeline(AggregationOperation... stages) { + return pipeline(AggregationPipeline.of(stages)); + } + + /** + * @param name the name of the new array field to add to the input documents, must not be {@literal null} or empty. + * @return new instance of {@link LookupOperation}. + */ + LookupOperation as(String name); + } + + public static interface AsBuilder extends PipelineBuilder { /** * @param name the name of the new array field to add to the input documents, must not be {@literal null} or empty. @@ -159,10 +264,12 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe public static final class LookupOperationBuilder implements FromBuilder, LocalFieldBuilder, ForeignFieldBuilder, AsBuilder { - private @Nullable Field from; + private @Nullable String from; private @Nullable Field localField; private @Nullable Field foreignField; private @Nullable ExposedField as; + private @Nullable Let let; + private @Nullable AggregationPipeline pipeline; /** * Creates new builder for {@link LookupOperation}. @@ -177,18 +284,10 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe public LocalFieldBuilder from(String name) { Assert.hasText(name, "'From' must not be null or empty"); - from = Fields.field(name); + from = name; return this; } - @Override - public LookupOperation as(String name) { - - Assert.hasText(name, "'As' must not be null or empty"); - as = new ExposedField(Fields.field(name), true); - return new LookupOperation(from, localField, foreignField, as); - } - @Override public AsBuilder foreignField(String name) { @@ -204,50 +303,29 @@ public class LookupOperation implements FieldsExposingAggregationOperation, Inhe localField = Fields.field(name); return this; } - } - - public static class Let implements AggregationExpression{ - - private final List vars; - - public Let(List vars) { - Assert.notEmpty(vars, "'let' must not be null or empty"); - this.vars = vars; - } @Override - public Document toDocument(AggregationOperationContext context) { - return toLet(); - } - - private Document toLet() { - Document mappedVars = new Document(); + public PipelineBuilder let(Let let) { - for (ExpressionVariable var : this.vars) { - mappedVars.putAll(getMappedVariable(var)); - } - - return mappedVars; + Assert.notNull(let, "Let must not be null"); + this.let = let; + return this; } - private Document getMappedVariable(ExpressionVariable var) { - return new Document(var.variableName, prefixDollarSign(var.expression)); - } + @Override + public AsBuilder pipeline(AggregationPipeline pipeline) { - private String prefixDollarSign(String expression) { - return "$" + expression; + Assert.notNull(pipeline, "Pipeline must not be null"); + this.pipeline = pipeline; + return this; } - public static class ExpressionVariable { - - private final String variableName; - - private final String expression; + @Override + public LookupOperation as(String name) { - public ExpressionVariable(String variableName, String expression) { - this.variableName = variableName; - this.expression = expression; - } + Assert.hasText(name, "'As' must not be null or empty"); + as = new ExposedField(Fields.field(name), true); + return new LookupOperation(from, localField, foreignField, let, pipeline, as); } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java index 53cbd4c5e..e8cf46802 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java @@ -16,7 +16,6 @@ package org.springframework.data.mongodb.core.aggregation; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -224,28 +223,41 @@ public class VariableOperators { public static class Let implements AggregationExpression { private final List vars; + + @Nullable // private final AggregationExpression expression; - private Let(List vars, AggregationExpression expression) { + private Let(List vars, @Nullable AggregationExpression expression) { this.vars = vars; this.expression = expression; } + /** + * Create a new {@link Let} holding just the given {@literal variables}. + * + * @param variables must not be {@literal null}. + * @return new instance of {@link Let}. + * @since 4.1 + */ + public static Let just(ExpressionVariable... variables) { + return new Let(List.of(variables), null); + } + /** * Start creating new {@link Let} by defining the variables for {@code $vars}. * * @param variables must not be {@literal null}. * @return */ - public static LetBuilder define(final Collection variables) { + public static LetBuilder define(Collection variables) { Assert.notNull(variables, "Variables must not be null"); return new LetBuilder() { @Override - public Let andApply(final AggregationExpression expression) { + public Let andApply(AggregationExpression expression) { Assert.notNull(expression, "Expression must not be null"); return new Let(new ArrayList(variables), expression); @@ -259,19 +271,10 @@ public class VariableOperators { * @param variables must not be {@literal null}. * @return */ - public static LetBuilder define(final ExpressionVariable... variables) { + public static LetBuilder define(ExpressionVariable... variables) { Assert.notNull(variables, "Variables must not be null"); - - return new LetBuilder() { - - @Override - public Let andApply(final AggregationExpression expression) { - - Assert.notNull(expression, "Expression must not be null"); - return new Let(Arrays.asList(variables), expression); - } - }; + return define(List.of(variables)); } public interface LetBuilder { @@ -283,10 +286,11 @@ public class VariableOperators { * @return */ Let andApply(AggregationExpression expression); + } @Override - public Document toDocument(final AggregationOperationContext context) { + public Document toDocument(AggregationOperationContext context) { return toLet(ExposedFields.synthetic(Fields.fields(getVariableNames())), context); } @@ -312,16 +316,22 @@ public class VariableOperators { } letExpression.put("vars", mappedVars); - letExpression.put("in", getMappedIn(operationContext)); + if (expression != null) { + letExpression.put("in", getMappedIn(operationContext)); + } return new Document("$let", letExpression); } private Document getMappedVariable(ExpressionVariable var, AggregationOperationContext context) { - return new Document(var.variableName, - var.expression instanceof AggregationExpression ? ((AggregationExpression) var.expression).toDocument(context) - : var.expression); + if (var.expression instanceof AggregationExpression expression) { + return new Document(var.variableName, expression.toDocument(context)); + } + if (var.expression instanceof Field field) { + return new Document(var.variableName, context.getReference(field).toString()); + } + return new Document(var.variableName, var.expression); } private Object getMappedIn(AggregationOperationContext context) { @@ -373,6 +383,10 @@ public class VariableOperators { return new ExpressionVariable(variableName, expression); } + public ExpressionVariable forField(String fieldRef) { + return new ExpressionVariable(variableName, Fields.field(fieldRef)); + } + /** * Create a new {@link ExpressionVariable} with current name and given {@literal expressionObject}. * diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index 77f06796c..b992a5ecb 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -60,6 +60,7 @@ import org.springframework.data.mongodb.core.TestEntities; import org.springframework.data.mongodb.core.Venue; import org.springframework.data.mongodb.core.aggregation.AggregationTests.CarDescriptor.Entry; import org.springframework.data.mongodb.core.aggregation.BucketAutoOperation.Granularities; +import org.springframework.data.mongodb.core.aggregation.VariableOperators.Let; import org.springframework.data.mongodb.core.aggregation.VariableOperators.Let.ExpressionVariable; import org.springframework.data.mongodb.core.geo.GeoJsonPoint; import org.springframework.data.mongodb.core.index.GeoSpatialIndexType; @@ -1518,12 +1519,12 @@ public class AggregationTests { assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1"); } - @Test + @Test // GH-3322 void shouldLookupPeopleCorrectlyWithPipeline() { createUsersWithReferencedPersons(); TypedAggregation agg = newAggregation(User.class, // - lookup("person", "_id", "firstname", "linkedPerson", List.of(match(where("firstname").is("u1")))), // + lookup().from("person").localField("_id").foreignField("firstname").pipeline(match(where("firstname").is("u1"))).as("linkedPerson"), // sort(ASC, "id")); AggregationResults results = mongoTemplate.aggregate(agg, User.class, Document.class); @@ -1536,18 +1537,13 @@ public class AggregationTests { assertThat(firstItem).containsEntry("linkedPerson.[0].firstname", "u1"); } - @Test + @Test // GH-3322 void shouldLookupPeopleCorrectlyWithPipelineAndLet() { createUsersWithReferencedPersons(); TypedAggregation agg = newAggregation(User.class, // - lookup( - "person", - "_id", - "firstname", - "linkedPerson", - List.of(new LookupOperation.Let.ExpressionVariable("personFirstname", "firstname")), - List.of(match(where("firstname").is("u1")))), + lookup().from("person").localField("_id").foreignField("firstname").let(Let.ExpressionVariable.newVariable("the_id").forField("_id")).pipeline( + match(ctx -> new Document("$expr", new Document("$eq", List.of("$$the_id", "u1"))))).as("linkedPerson"), sort(ASC, "id")); AggregationResults results = mongoTemplate.aggregate(agg, User.class, Document.class); @@ -1561,7 +1557,7 @@ public class AggregationTests { } @Test // DATAMONGO-1326 - void shouldGroupByAndLookupPeopleCorectly() { + void shouldGroupByAndLookupPeopleCorrectly() { createUsersWithReferencedPersons(); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/LookupOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/LookupOperationUnitTests.java index 360a23776..45b63763c 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/LookupOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/LookupOperationUnitTests.java @@ -16,11 +16,16 @@ package org.springframework.data.mongodb.core.aggregation; import static org.assertj.core.api.Assertions.*; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; +import static org.springframework.data.mongodb.core.aggregation.VariableOperators.Let.ExpressionVariable.*; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; + +import java.util.List; import org.bson.Document; import org.junit.jupiter.api.Test; - import org.springframework.data.mongodb.core.DocumentTestUtils; +import org.springframework.data.mongodb.core.query.Criteria; /** * Unit tests for {@link LookupOperation}. @@ -62,7 +67,7 @@ public class LookupOperationUnitTests { Document lookupClause = extractDocumentFromLookupOperation(lookupOperation); - assertThat(lookupClause).containsEntry("from", "a") // + org.assertj.core.api.Assertions.assertThat(lookupClause).containsEntry("from", "a") // .containsEntry("localField", "b") // .containsEntry("foreignField", "c") // .containsEntry("as", "d"); @@ -114,7 +119,7 @@ public class LookupOperationUnitTests { Document lookupClause = extractDocumentFromLookupOperation(lookupOperation); - assertThat(lookupClause).containsEntry("from", "a") // + org.assertj.core.api.Assertions.assertThat(lookupClause).containsEntry("from", "a") // .containsEntry("localField", "b") // .containsEntry("foreignField", "c") // .containsEntry("as", "d"); @@ -129,4 +134,86 @@ public class LookupOperationUnitTests { assertThat(lookupOperation.getFields().exposesSingleFieldOnly()).isTrue(); assertThat(lookupOperation.getFields().getField("d")).isNotNull(); } + + @Test // GH-3322 + void buildsLookupWithLetAndPipeline() { + + LookupOperation lookupOperation = LookupOperation.newLookup().from("warehouses") + .let(newVariable("order_item").forField("item"), newVariable("order_qty").forField("ordered")) + .pipeline(match(ctx -> new Document("$expr", + new Document("$and", List.of(Document.parse("{ $eq: [ \"$stock_item\", \"$$order_item\" ] }"), + Document.parse("{ $gte: [ \"$instock\", \"$$order_qty\" ] }")))))) + .as("stockdata"); + + assertThat(lookupOperation.toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo(""" + { $lookup: { + from: "warehouses", + let: { order_item: "$item", order_qty: "$ordered" }, + pipeline: [ + { $match: + { $expr: + { $and: + [ + { $eq: [ "$stock_item", "$$order_item" ] }, + { $gte: [ "$instock", "$$order_qty" ] } + ] + } + } + } + ], + as: "stockdata" + }} + """); + } + + @Test // GH-3322 + void buildsLookupWithJustPipeline() { + + LookupOperation lookupOperation = LookupOperation.newLookup().from("holidays") // + .pipeline( // + match(Criteria.where("year").is(2018)), // + project().andExclude("_id").and(ctx -> new Document("name", "$name").append("date", "$date")).as("date"), // + Aggregation.replaceRoot("date") // + ).as("holidays"); + + assertThat(lookupOperation.toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo(""" + { $lookup: + { + from: "holidays", + pipeline: [ + { $match: { year: 2018 } }, + { $project: { _id: 0, date: { name: "$name", date: "$date" } } }, + { $replaceRoot: { newRoot: "$date" } } + ], + as: "holidays" + } + }} + """); + } + + @Test // GH-3322 + void buildsLookupWithLocalAndForeignFieldAsWellAsLetAndPipeline() { + + LookupOperation lookupOperation = Aggregation.lookup().from("restaurants") // + .localField("restaurant_name") + .foreignField("name") + .let(newVariable("orders_drink").forField("drink")) // + .pipeline(match(ctx -> new Document("$expr", new Document("$in", List.of("$$orders_drink", "$beverages"))))) + .as("matches"); + + assertThat(lookupOperation.toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo(""" + { $lookup: { + from: "restaurants", + localField: "restaurant_name", + foreignField: "name", + let: { orders_drink: "$drink" }, + pipeline: [{ + $match: { + $expr: { $in: [ "$$orders_drink", "$beverages" ] } + } + }], + as: "matches" + }} + """); + } }