diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index fa2f63afd..b2c6df518 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -318,11 +318,12 @@ public class Aggregation { } /** - * Creates a new {@link GraphLookupOperation.FromBuilder} to construct a {@link GraphLookupOperation} given - * {@literal fromCollection}. + * Creates a new {@link GraphLookupOperation.GraphLookupOperationFromBuilder} to construct a + * {@link GraphLookupOperation} given {@literal fromCollection}. * * @param fromCollection must not be {@literal null} or empty. * @return + * @since 1.10 */ public static StartWithBuilder graphLookup(String fromCollection) { return GraphLookupOperation.builder().from(fromCollection); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java index 4dd36b99d..d6f840ec4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java @@ -17,30 +17,35 @@ package org.springframework.data.mongodb.core.aggregation; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; import java.util.List; +import java.util.Set; +import org.bson.Document; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; import org.springframework.data.mongodb.core.aggregation.FieldsExposingAggregationOperation.InheritsFieldsAggregationOperation; import org.springframework.data.mongodb.core.query.CriteriaDefinition; import org.springframework.util.Assert; - -import org.bson.Document; +import org.springframework.util.ClassUtils; /** - * Encapsulates the aggregation framework {@code $graphLookup}-operation. - *

+ * Encapsulates the aggregation framework {@code $graphLookup}-operation.
* Performs a recursive search on a collection, with options for restricting the search by recursion depth and query - * filter. - *

+ * filter.
* We recommend to use the static factory method {@link Aggregation#graphLookup(String)} instead of creating instances * of this class directly. * - * @see http://docs.mongodb.org/manual/reference/aggregation/graphLookup/ + * @see http://docs.mongodb.org/manual/reference/aggregation/graphLookup/ * @author Mark Paluch + * @author Christoph Strobl * @since 1.10 */ public class GraphLookupOperation implements InheritsFieldsAggregationOperation { + private static final Set> ALLOWED_START_TYPES = new HashSet>( + Arrays.> asList(AggregationExpression.class, String.class, Field.class, Document.class)); + private final String from; private final List startWith; private final Field connectFrom; @@ -65,7 +70,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation /** * Creates a new {@link FromBuilder} to build {@link GraphLookupOperation}. - * + * * @return a new {@link FromBuilder}. */ public static FromBuilder builder() { @@ -73,7 +78,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation } /* (non-Javadoc) - * @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDBObject(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext) + * @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDocument(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext) */ @Override public Document toDocument(AggregationOperationContext context) { @@ -82,24 +87,20 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation graphLookup.put("from", from); - List list = new ArrayList<>(startWith.size()); + List mappedStartWith = new ArrayList(startWith.size()); for (Object startWithElement : startWith) { if (startWithElement instanceof AggregationExpression) { - list.add(((AggregationExpression) startWithElement).toDocument(context)); - } - - if (startWithElement instanceof Field) { - list.add(context.getReference((Field) startWithElement).toString()); + mappedStartWith.add(((AggregationExpression) startWithElement).toDocument(context)); + } else if (startWithElement instanceof Field) { + mappedStartWith.add(context.getReference((Field) startWithElement).toString()); + } else { + mappedStartWith.add(startWithElement); } } - if (list.size() == 1) { - graphLookup.put("startWith", list.get(0)); - } else { - graphLookup.put("startWith", list); - } + graphLookup.put("startWith", mappedStartWith.size() == 1 ? mappedStartWith.iterator().next() : mappedStartWith); graphLookup.put("connectFromField", connectFrom.getName()); graphLookup.put("connectToField", connectTo.getName()); @@ -145,6 +146,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation /** * @author Mark Paluch + * @author Christoph Strobl */ public interface StartWithBuilder { @@ -163,6 +165,16 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation * @return */ ConnectFromBuilder startWith(AggregationExpression... expressions); + + /** + * Set the startWith as either {@literal fieldReferences}, {@link Fields}, {@link Document} or + * {@link AggregationExpression} to apply the {@code $graphLookup} to. + * + * @param expressions must not be {@literal null}. + * @return + * @throws IllegalArgumentException + */ + ConnectFromBuilder startWith(Object... expressions); } /** @@ -196,7 +208,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation /** * Builder to build the initial {@link GraphLookupOperationBuilder} that configures the initial mandatory set of * {@link GraphLookupOperation} properties. - * + * * @author Mark Paluch */ static final class GraphLookupOperationFromBuilder @@ -215,7 +227,6 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation Assert.hasText(collectionName, "CollectionName must not be null or empty!"); this.from = collectionName; - return this; } @@ -235,7 +246,6 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation } this.startWith = fields; - return this; } @@ -249,10 +259,50 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation Assert.noNullElements(expressions, "AggregationExpressions must not contain null elements!"); this.startWith = Arrays.asList(expressions); + return this; + } + + @Override + public ConnectFromBuilder startWith(Object... expressions) { + Assert.notNull(expressions, "Expressions must not be null!"); + Assert.noNullElements(expressions, "Expressions must not contain null elements!"); + + this.startWith = verifyAndPotentiallyTransformStartsWithTypes(expressions); return this; } + private List verifyAndPotentiallyTransformStartsWithTypes(Object... expressions) { + + List expressionsToUse = new ArrayList(expressions.length); + + for (Object expression : expressions) { + + assertStartWithType(expression); + + if (expression instanceof String) { + expressionsToUse.add(Fields.field((String) expression)); + } else { + expressionsToUse.add(expression); + } + + } + return expressionsToUse; + } + + private void assertStartWithType(Object expression) { + + for (Class type : ALLOWED_START_TYPES) { + + if (ClassUtils.isAssignable(type, expression.getClass())) { + return; + } + } + + throw new IllegalArgumentException( + String.format("Expression must be any of %s but was %s", ALLOWED_START_TYPES, expression.getClass())); + } + /* (non-Javadoc) * @see org.springframework.data.mongodb.core.aggregation.GraphLookupOperation.ConnectFromBuilder#connectFrom(java.lang.String) */ @@ -262,7 +312,6 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation Assert.hasText(fieldName, "ConnectFrom must not be null or empty!"); this.connectFrom = fieldName; - return this; } @@ -301,8 +350,8 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation } /** - * Limit the number of recursions. - * + * Optionally limit the number of recursions. + * * @param numberOfRecursions must be greater or equal to zero. * @return */ @@ -311,13 +360,12 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation Assert.isTrue(numberOfRecursions >= 0, "Max depth must be >= 0!"); this.maxDepth = numberOfRecursions; - return this; } /** - * Add a depth field {@literal fieldName} to each traversed document in the search path. - * + * Optionally add a depth field {@literal fieldName} to each traversed document in the search path. + * * @param fieldName must not be {@literal null} or empty. * @return */ @@ -326,13 +374,12 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation Assert.hasText(fieldName, "Depth field name must not be null or empty!"); this.depthField = Fields.field(fieldName); - return this; } /** - * Add a query specifying conditions to the recursive search. - * + * Optionally add a query specifying conditions to the recursive search. + * * @param criteriaDefinition must not be {@literal null}. * @return */ @@ -341,14 +388,13 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation Assert.notNull(criteriaDefinition, "CriteriaDefinition must not be null!"); this.restrictSearchWithMatch = criteriaDefinition; - return this; } /** * Set the name of the array field added to each output document and return the final {@link GraphLookupOperation}. * Contains the documents traversed in the {@literal $graphLookup} stage to reach the document. - * + * * @param fieldName must not be {@literal null} or empty. * @return the final {@link GraphLookupOperation}. */ diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java index e5022d694..9290bc502 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java @@ -21,6 +21,7 @@ import static org.springframework.data.mongodb.test.util.IsBsonObject.*; import org.bson.Document; import org.junit.Test; +import org.springframework.data.mongodb.core.Person; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Literal; import org.springframework.data.mongodb.core.query.Criteria; @@ -30,8 +31,9 @@ import com.mongodb.util.JSON; /** * Unit tests for {@link GraphLookupOperation}. - * + * * @author Mark Paluch + * @author Christoph Strobl */ public class GraphLookupOperationUnitTests { @@ -102,6 +104,40 @@ public class GraphLookupOperationUnitTests { + "connectFromField: \"reportsTo\", connectToField: \"name\", as: \"reportingHierarchy\" } }"))); } + /** + * @see DATAMONGO-1551 + */ + @Test + public void shouldRenderMixedArrayOfStartsWithCorrectly() { + + GraphLookupOperation graphLookupOperation = GraphLookupOperation.builder() // + .from("employees") // + .startWith("reportsTo", Literal.asLiteral("$boss")) // + .connectFrom("reportsTo") // + .connectTo("name") // + .as("reportingHierarchy"); + + Document document = graphLookupOperation.toDocument(Aggregation.DEFAULT_CONTEXT); + + assertThat(document, + is(Document.parse("{ $graphLookup : { from: \"employees\", startWith: [\"$reportsTo\", { $literal: \"$boss\"}], " + + "connectFromField: \"reportsTo\", connectToField: \"name\", as: \"reportingHierarchy\" } }"))); + } + + /** + * @see DATAMONGO-1551 + */ + @Test(expected = IllegalArgumentException.class) + public void shouldRejectUnknownTypeInMixedArrayOfStartsWithCorrectly() { + + GraphLookupOperation graphLookupOperation = GraphLookupOperation.builder() // + .from("employees") // + .startWith("reportsTo", new Person()) // + .connectFrom("reportsTo") // + .connectTo("name") // + .as("reportingHierarchy"); + } + /** * @see DATAMONGO-1551 */ diff --git a/src/main/asciidoc/reference/mongodb.adoc b/src/main/asciidoc/reference/mongodb.adoc index 9b3154b17..30d10a414 100644 --- a/src/main/asciidoc/reference/mongodb.adoc +++ b/src/main/asciidoc/reference/mongodb.adoc @@ -1676,7 +1676,7 @@ At the time of this writing we provide support for the following Aggregation Ope [cols="2*"] |=== | Pipeline Aggregation Operators -| count, geoNear, group, limit, lookup, match, project, replaceRoot, skip, sort, unwind +| count, geoNear, graphLookup, group, limit, lookup, match, project, replaceRoot, skip, sort, unwind | Set Aggregation Operators | setEquals, setIntersection, setUnion, setDifference, setIsSubset, anyElementTrue, allElementsTrue