diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index f352c5728..f3d1d03ac 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -48,6 +48,15 @@ public class Aggregation { private final List operations; + /** + * Creates a new {@link Aggregation} from the given {@link AggregationOperation}s. + * + * @param operations must not be {@literal null} or empty. + */ + public static Aggregation newAggregation(List operations) { + return newAggregation(operations.toArray(new AggregationOperation[operations.size()])); + } + /** * Creates a new {@link Aggregation} from the given {@link AggregationOperation}s. * @@ -57,6 +66,16 @@ public class Aggregation { return new Aggregation(operations); } + /** + * Creates a new {@link TypedAggregation} for the given type and {@link AggregationOperation}s. + * + * @param type must not be {@literal null}. + * @param operations must not be {@literal null} or empty. + */ + public static TypedAggregation newAggregation(Class type, List operations) { + return newAggregation(type, operations.toArray(new AggregationOperation[operations.size()])); + } + /** * Creates a new {@link TypedAggregation} for the given type and {@link AggregationOperation}s. * diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java index ac6df23e1..06ee8a1c4 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java @@ -17,15 +17,16 @@ package org.springframework.data.mongodb.core.aggregation; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.*; +import static org.springframework.data.mongodb.core.DBObjectTestUtils.*; import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; import static org.springframework.data.mongodb.core.query.Criteria.*; +import java.util.ArrayList; import java.util.List; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.springframework.data.mongodb.core.DBObjectTestUtils; import com.mongodb.DBObject; @@ -116,7 +117,47 @@ public class AggregationUnitTests { @SuppressWarnings("unchecked") DBObject secondProjection = ((List) agg.get("pipeline")).get(2); - DBObject fields = DBObjectTestUtils.getAsDBObject(secondProjection, "$project"); + DBObject fields = getAsDBObject(secondProjection, "$project"); + assertThat(fields.get("aCnt"), is((Object) 1)); + assertThat(fields.get("a"), is((Object) "$_id.a")); + } + + /** + * @see DATAMONGO-791 + */ + @Test + public void allowAggregationOperationsToBePassedAsIterable() { + + List ops = new ArrayList(); + ops.add(project("a")); + ops.add(group("a").count().as("aCnt")); + ops.add(project("aCnt", "a")); + + DBObject agg = newAggregation(ops).toDbObject("foo", Aggregation.DEFAULT_CONTEXT); + + @SuppressWarnings("unchecked") + DBObject secondProjection = ((List) agg.get("pipeline")).get(2); + DBObject fields = getAsDBObject(secondProjection, "$project"); + assertThat(fields.get("aCnt"), is((Object) 1)); + assertThat(fields.get("a"), is((Object) "$_id.a")); + } + + /** + * @see DATAMONGO-791 + */ + @Test + public void allowTypedAggregationOperationsToBePassedAsIterable() { + + List ops = new ArrayList(); + ops.add(project("a")); + ops.add(group("a").count().as("aCnt")); + ops.add(project("aCnt", "a")); + + DBObject agg = newAggregation(DBObject.class, ops).toDbObject("foo", Aggregation.DEFAULT_CONTEXT); + + @SuppressWarnings("unchecked") + DBObject secondProjection = ((List) agg.get("pipeline")).get(2); + DBObject fields = getAsDBObject(secondProjection, "$project"); assertThat(fields.get("aCnt"), is((Object) 1)); assertThat(fields.get("a"), is((Object) "$_id.a")); }