diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java
index fa2f63afd..b2c6df518 100644
--- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java
+++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java
@@ -318,11 +318,12 @@ public class Aggregation {
}
/**
- * Creates a new {@link GraphLookupOperation.FromBuilder} to construct a {@link GraphLookupOperation} given
- * {@literal fromCollection}.
+ * Creates a new {@link GraphLookupOperation.GraphLookupOperationFromBuilder} to construct a
+ * {@link GraphLookupOperation} given {@literal fromCollection}.
*
* @param fromCollection must not be {@literal null} or empty.
* @return
+ * @since 1.10
*/
public static StartWithBuilder graphLookup(String fromCollection) {
return GraphLookupOperation.builder().from(fromCollection);
diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java
index 4dd36b99d..d6f840ec4 100644
--- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java
+++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java
@@ -17,30 +17,35 @@ package org.springframework.data.mongodb.core.aggregation;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashSet;
import java.util.List;
+import java.util.Set;
+import org.bson.Document;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField;
import org.springframework.data.mongodb.core.aggregation.FieldsExposingAggregationOperation.InheritsFieldsAggregationOperation;
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
import org.springframework.util.Assert;
-
-import org.bson.Document;
+import org.springframework.util.ClassUtils;
/**
- * Encapsulates the aggregation framework {@code $graphLookup}-operation.
- *
+ * Encapsulates the aggregation framework {@code $graphLookup}-operation.
* Performs a recursive search on a collection, with options for restricting the search by recursion depth and query
- * filter.
- *
+ * filter.
* We recommend to use the static factory method {@link Aggregation#graphLookup(String)} instead of creating instances
* of this class directly.
*
- * @see http://docs.mongodb.org/manual/reference/aggregation/graphLookup/
+ * @see http://docs.mongodb.org/manual/reference/aggregation/graphLookup/
* @author Mark Paluch
+ * @author Christoph Strobl
* @since 1.10
*/
public class GraphLookupOperation implements InheritsFieldsAggregationOperation {
+ private static final Set> ALLOWED_START_TYPES = new HashSet>(
+ Arrays.> asList(AggregationExpression.class, String.class, Field.class, Document.class));
+
private final String from;
private final List startWith;
private final Field connectFrom;
@@ -65,7 +70,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
/**
* Creates a new {@link FromBuilder} to build {@link GraphLookupOperation}.
- *
+ *
* @return a new {@link FromBuilder}.
*/
public static FromBuilder builder() {
@@ -73,7 +78,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
}
/* (non-Javadoc)
- * @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDBObject(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext)
+ * @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDocument(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext)
*/
@Override
public Document toDocument(AggregationOperationContext context) {
@@ -82,24 +87,20 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
graphLookup.put("from", from);
- List list = new ArrayList<>(startWith.size());
+ List mappedStartWith = new ArrayList(startWith.size());
for (Object startWithElement : startWith) {
if (startWithElement instanceof AggregationExpression) {
- list.add(((AggregationExpression) startWithElement).toDocument(context));
- }
-
- if (startWithElement instanceof Field) {
- list.add(context.getReference((Field) startWithElement).toString());
+ mappedStartWith.add(((AggregationExpression) startWithElement).toDocument(context));
+ } else if (startWithElement instanceof Field) {
+ mappedStartWith.add(context.getReference((Field) startWithElement).toString());
+ } else {
+ mappedStartWith.add(startWithElement);
}
}
- if (list.size() == 1) {
- graphLookup.put("startWith", list.get(0));
- } else {
- graphLookup.put("startWith", list);
- }
+ graphLookup.put("startWith", mappedStartWith.size() == 1 ? mappedStartWith.iterator().next() : mappedStartWith);
graphLookup.put("connectFromField", connectFrom.getName());
graphLookup.put("connectToField", connectTo.getName());
@@ -145,6 +146,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
/**
* @author Mark Paluch
+ * @author Christoph Strobl
*/
public interface StartWithBuilder {
@@ -163,6 +165,16 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
* @return
*/
ConnectFromBuilder startWith(AggregationExpression... expressions);
+
+ /**
+ * Set the startWith as either {@literal fieldReferences}, {@link Fields}, {@link Document} or
+ * {@link AggregationExpression} to apply the {@code $graphLookup} to.
+ *
+ * @param expressions must not be {@literal null}.
+ * @return
+ * @throws IllegalArgumentException
+ */
+ ConnectFromBuilder startWith(Object... expressions);
}
/**
@@ -196,7 +208,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
/**
* Builder to build the initial {@link GraphLookupOperationBuilder} that configures the initial mandatory set of
* {@link GraphLookupOperation} properties.
- *
+ *
* @author Mark Paluch
*/
static final class GraphLookupOperationFromBuilder
@@ -215,7 +227,6 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
Assert.hasText(collectionName, "CollectionName must not be null or empty!");
this.from = collectionName;
-
return this;
}
@@ -235,7 +246,6 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
}
this.startWith = fields;
-
return this;
}
@@ -249,10 +259,50 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
Assert.noNullElements(expressions, "AggregationExpressions must not contain null elements!");
this.startWith = Arrays.asList(expressions);
+ return this;
+ }
+
+ @Override
+ public ConnectFromBuilder startWith(Object... expressions) {
+ Assert.notNull(expressions, "Expressions must not be null!");
+ Assert.noNullElements(expressions, "Expressions must not contain null elements!");
+
+ this.startWith = verifyAndPotentiallyTransformStartsWithTypes(expressions);
return this;
}
+ private List verifyAndPotentiallyTransformStartsWithTypes(Object... expressions) {
+
+ List expressionsToUse = new ArrayList(expressions.length);
+
+ for (Object expression : expressions) {
+
+ assertStartWithType(expression);
+
+ if (expression instanceof String) {
+ expressionsToUse.add(Fields.field((String) expression));
+ } else {
+ expressionsToUse.add(expression);
+ }
+
+ }
+ return expressionsToUse;
+ }
+
+ private void assertStartWithType(Object expression) {
+
+ for (Class> type : ALLOWED_START_TYPES) {
+
+ if (ClassUtils.isAssignable(type, expression.getClass())) {
+ return;
+ }
+ }
+
+ throw new IllegalArgumentException(
+ String.format("Expression must be any of %s but was %s", ALLOWED_START_TYPES, expression.getClass()));
+ }
+
/* (non-Javadoc)
* @see org.springframework.data.mongodb.core.aggregation.GraphLookupOperation.ConnectFromBuilder#connectFrom(java.lang.String)
*/
@@ -262,7 +312,6 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
Assert.hasText(fieldName, "ConnectFrom must not be null or empty!");
this.connectFrom = fieldName;
-
return this;
}
@@ -301,8 +350,8 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
}
/**
- * Limit the number of recursions.
- *
+ * Optionally limit the number of recursions.
+ *
* @param numberOfRecursions must be greater or equal to zero.
* @return
*/
@@ -311,13 +360,12 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
Assert.isTrue(numberOfRecursions >= 0, "Max depth must be >= 0!");
this.maxDepth = numberOfRecursions;
-
return this;
}
/**
- * Add a depth field {@literal fieldName} to each traversed document in the search path.
- *
+ * Optionally add a depth field {@literal fieldName} to each traversed document in the search path.
+ *
* @param fieldName must not be {@literal null} or empty.
* @return
*/
@@ -326,13 +374,12 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
Assert.hasText(fieldName, "Depth field name must not be null or empty!");
this.depthField = Fields.field(fieldName);
-
return this;
}
/**
- * Add a query specifying conditions to the recursive search.
- *
+ * Optionally add a query specifying conditions to the recursive search.
+ *
* @param criteriaDefinition must not be {@literal null}.
* @return
*/
@@ -341,14 +388,13 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
Assert.notNull(criteriaDefinition, "CriteriaDefinition must not be null!");
this.restrictSearchWithMatch = criteriaDefinition;
-
return this;
}
/**
* Set the name of the array field added to each output document and return the final {@link GraphLookupOperation}.
* Contains the documents traversed in the {@literal $graphLookup} stage to reach the document.
- *
+ *
* @param fieldName must not be {@literal null} or empty.
* @return the final {@link GraphLookupOperation}.
*/
diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java
index e5022d694..9290bc502 100644
--- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java
+++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java
@@ -21,6 +21,7 @@ import static org.springframework.data.mongodb.test.util.IsBsonObject.*;
import org.bson.Document;
import org.junit.Test;
+import org.springframework.data.mongodb.core.Person;
import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Literal;
import org.springframework.data.mongodb.core.query.Criteria;
@@ -30,8 +31,9 @@ import com.mongodb.util.JSON;
/**
* Unit tests for {@link GraphLookupOperation}.
- *
+ *
* @author Mark Paluch
+ * @author Christoph Strobl
*/
public class GraphLookupOperationUnitTests {
@@ -102,6 +104,40 @@ public class GraphLookupOperationUnitTests {
+ "connectFromField: \"reportsTo\", connectToField: \"name\", as: \"reportingHierarchy\" } }")));
}
+ /**
+ * @see DATAMONGO-1551
+ */
+ @Test
+ public void shouldRenderMixedArrayOfStartsWithCorrectly() {
+
+ GraphLookupOperation graphLookupOperation = GraphLookupOperation.builder() //
+ .from("employees") //
+ .startWith("reportsTo", Literal.asLiteral("$boss")) //
+ .connectFrom("reportsTo") //
+ .connectTo("name") //
+ .as("reportingHierarchy");
+
+ Document document = graphLookupOperation.toDocument(Aggregation.DEFAULT_CONTEXT);
+
+ assertThat(document,
+ is(Document.parse("{ $graphLookup : { from: \"employees\", startWith: [\"$reportsTo\", { $literal: \"$boss\"}], "
+ + "connectFromField: \"reportsTo\", connectToField: \"name\", as: \"reportingHierarchy\" } }")));
+ }
+
+ /**
+ * @see DATAMONGO-1551
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void shouldRejectUnknownTypeInMixedArrayOfStartsWithCorrectly() {
+
+ GraphLookupOperation graphLookupOperation = GraphLookupOperation.builder() //
+ .from("employees") //
+ .startWith("reportsTo", new Person()) //
+ .connectFrom("reportsTo") //
+ .connectTo("name") //
+ .as("reportingHierarchy");
+ }
+
/**
* @see DATAMONGO-1551
*/
diff --git a/src/main/asciidoc/reference/mongodb.adoc b/src/main/asciidoc/reference/mongodb.adoc
index 9b3154b17..30d10a414 100644
--- a/src/main/asciidoc/reference/mongodb.adoc
+++ b/src/main/asciidoc/reference/mongodb.adoc
@@ -1676,7 +1676,7 @@ At the time of this writing we provide support for the following Aggregation Ope
[cols="2*"]
|===
| Pipeline Aggregation Operators
-| count, geoNear, group, limit, lookup, match, project, replaceRoot, skip, sort, unwind
+| count, geoNear, graphLookup, group, limit, lookup, match, project, replaceRoot, skip, sort, unwind
| Set Aggregation Operators
| setEquals, setIntersection, setUnion, setDifference, setIsSubset, anyElementTrue, allElementsTrue