From 7523eedd8de6a64cda5bb70a6b63e29a61bc4966 Mon Sep 17 00:00:00 2001 From: Sergey Shcherbakov Date: Tue, 19 Sep 2017 08:49:34 +0200 Subject: [PATCH] DATAMONGO-1784 - Add expression support to GroupOperation#sum(). We now allow passing an AggregationExpression to GroupOperation.sum which allows construction of more complex expressions. Original Pull Request: #501 --- .../core/aggregation/GroupOperation.java | 12 +++++ .../core/aggregation/AggregationTests.java | 44 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java index 49ef8d24e..4f023de69 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java @@ -40,6 +40,7 @@ import com.mongodb.DBObject; * @author Oliver Gierke * @author Gustavo de Geus * @author Christoph Strobl + * @author Sergey Shcherbakov * @since 1.3 * @see MongoDB Aggregation Framework: $group */ @@ -157,6 +158,17 @@ public class GroupOperation implements FieldsExposingAggregationOperation { return sum(reference, null); } + /** + * Generates an {@link GroupOperationBuilder} for an {@code $sum}-expression for the given + * {@link AggregationExpression}. + * + * @param expr + * @return + */ + public GroupOperationBuilder sum(AggregationExpression expr) { + return newBuilder(GroupOps.SUM, null, expr); + } + private GroupOperationBuilder sum(String reference, Object value) { return newBuilder(GroupOps.SUM, reference, value); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index 9b9766370..67a8054ea 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -88,6 +88,7 @@ import com.mongodb.util.JSON; * @author Christoph Strobl * @author Mark Paluch * @author Nikolay Bogdanov + * @author Sergey Shcherbakov */ @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration("classpath:infrastructure.xml") @@ -726,6 +727,49 @@ public class AggregationTests { assertThat(((Number) good.get("score")).longValue(), is(equalTo(9000L))); } + @Test // DATAMONGO-1784 + public void shouldAllowSumUsingConditionalExpressions() { + + mongoTemplate.dropCollection(CarPerson.class); + + CarPerson person1 = new CarPerson("first1", "last1", new CarDescriptor.Entry("MAKE1", "MODEL1", 2000), + new CarDescriptor.Entry("MAKE1", "MODEL2", 2001)); + + CarPerson person2 = new CarPerson("first2", "last2", new CarDescriptor.Entry("MAKE3", "MODEL4", 2014)); + CarPerson person3 = new CarPerson("first3", "last3", new CarDescriptor.Entry("MAKE2", "MODEL5", 2015)); + + mongoTemplate.save(person1); + mongoTemplate.save(person2); + mongoTemplate.save(person3); + + TypedAggregation agg = Aggregation.newAggregation(CarPerson.class, + unwind("descriptors.carDescriptor.entries"), // + project() // + .and(ConditionalOperators // + .when(Criteria.where("descriptors.carDescriptor.entries.make").is("MAKE1")).then("good") + .otherwise("meh")) + .as("make") // + .and("descriptors.carDescriptor.entries.model").as("model") // + .and("descriptors.carDescriptor.entries.year").as("year"), // + group("make").sum(ConditionalOperators // + .when(Criteria.where("year").gte(2012)) // + .then(1) // + .otherwise(9000)).as("score"), + sort(ASC, "make")); + + AggregationResults result = mongoTemplate.aggregate(agg, DBObject.class); + + assertThat(result.getMappedResults(), hasSize(2)); + + DBObject meh = result.getMappedResults().get(0); + assertThat((String) meh.get("_id"), is(equalTo("meh"))); + assertThat(((Number) meh.get("score")).longValue(), is(equalTo(2L))); + + DBObject good = result.getMappedResults().get(1); + assertThat((String) good.get("_id"), is(equalTo("good"))); + assertThat(((Number) good.get("score")).longValue(), is(equalTo(18000L))); + } + /** * @see Return