diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java index 205edc469..e3d8f1f3a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java @@ -16,6 +16,8 @@ package org.springframework.data.mongodb.core.aggregation; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Locale; @@ -53,20 +55,29 @@ public class GroupOperation extends ExposedFieldsAggregationOperationContext imp } /** - * Creates a new {@link GroupOperation} from the given {@link GroupOperation} and the given {@link Operation}. + * Creates a new {@link GroupOperation} from the given {@link GroupOperation}. * - * @param current must not be {@literal null}. - * @param operation must not be {@literal null}. + * @param groupOperation must not be {@literal null}. + */ + protected GroupOperation(GroupOperation groupOperation) { + this(groupOperation, Collections. emptyList()); + } + + /** + * Creates a new {@link GroupOperation} from the given {@link GroupOperation} and the given {@link Operation}s. + * + * @param groupOperation + * @param nextOperations */ - protected GroupOperation(GroupOperation current, Operation operation) { + private GroupOperation(GroupOperation groupOperation, List nextOperations) { - Assert.notNull(current, "GroupOperation must not be null!"); - Assert.notNull(operation, "Operation must not be null!"); + Assert.notNull(groupOperation, "GroupOperation must not be null!"); + Assert.notNull(nextOperations, "NextOperations must not be null!"); - this.nonSynthecticFields = current.nonSynthecticFields; - this.operations = new ArrayList(current.operations.size() + 1); - this.operations.addAll(current.operations); - this.operations.add(operation); + this.nonSynthecticFields = groupOperation.nonSynthecticFields; + this.operations = new ArrayList(nextOperations.size() + 1); + this.operations.addAll(groupOperation.operations); + this.operations.addAll(nextOperations); } /** @@ -76,88 +87,171 @@ public class GroupOperation extends ExposedFieldsAggregationOperationContext imp * @return */ protected GroupOperation and(Operation operation) { - return new GroupOperation(this, operation); + return new GroupOperation(this, Arrays.asList(operation)); } /** - * Returns a {@link GroupOperationBuilder} to build a grouping operation for the field with the given name + * Builder for {@link GroupOperation}s on a field. * - * @param field must not be {@literal null} or empty. - * @return + * @author Thomas Darimont */ - public GroupOperationBuilder and(String field) { - return new GroupOperationBuilder(field, this); - } - public class GroupOperationBuilder { - private final String name; - private final GroupOperation current; + private final GroupOperation groupOperation; + private final Operation operation; - public GroupOperationBuilder(String name, GroupOperation current) { + /** + * Creates a new {@link GroupOperationBuilder} from the given {@link GroupOperation} and {@link Operation}. + * + * @param groupOperation + * @param operation + */ + private GroupOperationBuilder(GroupOperation groupOperation, Operation operation) { - Assert.hasText(name, "Field name must not be null or empty!"); - Assert.notNull(current, "GroupOperation must not be null!"); + Assert.notNull(groupOperation, "GroupOperation must not be null!"); + Assert.notNull(operation, "Operation must not be null!"); - this.name = name; - this.current = current; + this.groupOperation = groupOperation; + this.operation = operation; } - public GroupOperation count() { - return sum(1); + /** + * Allows to specify an alias for the new-operation operation. + * + * @param alias + * @return + */ + public GroupOperation as(String alias) { + return this.groupOperation.and(operation.withAlias(alias)); } + } - public GroupOperation count(String reference) { - return sum(reference, 1); - } + /** + * Generates an {@link GroupOperationBuilder} for a {@code $sum}-expression. + *

+ * Count expressions are emulated via {@code $sum: 1}. + *

+ * + * @return + */ + public GroupOperationBuilder count() { + return newBuilder(GroupOps.SUM, null, 1); + } - public GroupOperation sum() { - return sum(name); - } + /** + * Generates an {@link GroupOperationBuilder} for a {@code $sum}-expression for the given field-reference. + * + * @param reference + * @return + */ + public GroupOperationBuilder sum(String reference) { + return sum(reference, null); + } - public GroupOperation sum(String reference) { - return sum(reference, null); - } + private GroupOperationBuilder sum(String reference, Object value) { + return newBuilder(GroupOps.SUM, reference, value); + } - public GroupOperation sum(Object value) { - return sum(null, value); - } + /** + * Generates an {@link GroupOperationBuilder} for an {@code $add_to_set}-expression for the given field-reference. + * + * @param reference + * @return + */ + public GroupOperationBuilder addToSet(String reference) { + return addToSet(reference, null); + } - public GroupOperation sum(String reference, Object value) { - return current.and(new Operation(GroupOps.SUM, name, reference, value)); - } + /** + * Generates an {@link GroupOperationBuilder} for an {@code $add_to_set}-expression for the given value. + * + * @param value + * @return + */ + public GroupOperationBuilder addToSet(Object value) { + return addToSet(null, value); + } - public GroupOperation addToSet() { - return addToSet(null); - } + private GroupOperationBuilder addToSet(String reference, Object value) { + return newBuilder(GroupOps.ADD_TO_SET, reference, value); + } - public GroupOperation addToSet(String reference) { - return current.and(new Operation(GroupOps.ADD_TO_SET, name, reference, null)); - } + /** + * Generates an {@link GroupOperationBuilder} for an {@code $last}-expression for the given field-reference. + * + * @param reference + * @return + */ + public GroupOperationBuilder last(String reference) { + return newBuilder(GroupOps.LAST, reference, null); + } - public GroupOperation last() { - return last(null); - } + /** + * Generates an {@link GroupOperationBuilder} for a {@code $first}-expression for the given field-reference. + * + * @param reference + * @return + */ + public GroupOperationBuilder first(String reference) { + return newBuilder(GroupOps.FIRST, reference, null); + } - public GroupOperation last(String reference) { - return current.and(new Operation(GroupOps.LAST, name, reference, null)); - } + /** + * Generates an {@link GroupOperationBuilder} for an {@code $avg}-expression for the given field-reference. + * + * @param reference + * @return + */ + public GroupOperationBuilder avg(String reference) { + return newBuilder(GroupOps.AVG, reference, null); + } - public GroupOperation first() { - return first(null); - } + /** + * Generates an {@link GroupOperationBuilder} for an {@code $push}-expression for the given field-reference. + * + * @param reference + * @return + */ + public GroupOperationBuilder push(String reference) { + return push(reference, null); + } - public GroupOperation first(String reference) { - return current.and(new Operation(GroupOps.FIRST, name, reference, null)); - } + /** + * Generates an {@link GroupOperationBuilder} for an {@code $push}-expression for the given value. + * + * @param value + * @return + */ + public GroupOperationBuilder push(Object value) { + return push(null, value); + } - public GroupOperation avg() { - return avg(null); - } + private GroupOperationBuilder push(String reference, Object value) { + return newBuilder(GroupOps.PUSH, reference, value); + } - public GroupOperation avg(String reference) { - return current.and(new Operation(GroupOps.AVG, name, reference, null)); - } + /** + * Generates an {@link GroupOperationBuilder} for an {@code $min}-expression that for the given field-reference. + * + * @param reference + * @return + */ + public GroupOperationBuilder min(String reference) { + return newBuilder(GroupOps.MIN, reference, null); + } + + /** + * Generates an {@link GroupOperationBuilder} for an {@code $max}-expression that for the given field-reference. + * + * @param reference + * @return + */ + public GroupOperationBuilder max(String reference) { + return newBuilder(GroupOps.MAX, reference, null); + } + + private GroupOperationBuilder newBuilder(Keyword keyword, String reference, Object value) { + return new GroupOperationBuilder(this, new Operation(keyword, null, reference, value)); } /* @@ -249,6 +343,10 @@ public class GroupOperation extends ExposedFieldsAggregationOperationContext imp this.value = value; } + public Operation withAlias(String key) { + return new Operation(op, key, reference, value); + } + public ExposedField asField() { return new ExposedField(key, true); } @@ -260,5 +358,10 @@ public class GroupOperation extends ExposedFieldsAggregationOperationContext imp public Object getValue(AggregationOperationContext context) { return reference == null ? value : context.getReference(reference).toString(); } + + @Override + public String toString() { + return "Operation [op=" + op + ", key=" + key + ", reference=" + reference + ", value=" + value + "]"; + } } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index e4a841d7e..48d25d39c 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -155,7 +155,7 @@ public class AggregationTests { project("tags"), // unwind("tags"), // group("tags") // - .and("n").count(), // + .count().as("n"), // project("n") // .and("tag").previousOperation(), // sort(DESC, "n") // @@ -183,7 +183,7 @@ public class AggregationTests { project("tags"), // unwind("tags"), // group("tags") // - .and("n").count(), // + .count().as("n"), // project("n") // .and("tag").previousOperation(), // sort(DESC, "n") // @@ -209,7 +209,7 @@ public class AggregationTests { project("tags"), // unwind("tags"), // group("tags") // - .and("count").count(), // + .count().as("count"), // count field not present limit(2) // ); @@ -289,13 +289,13 @@ public class AggregationTests { */ TypedAggregation aggregation = newAggregation(ZipInfo.class, // - group("state", "city").and("pop").sum("population"), // + group("state", "city").sum("population").as("pop"), // sort(ASC, "pop", "state", "city"), // group("state") // - .and("biggestCity").last("city") // - .and("biggestPop").last("pop") // - .and("smallestCity").first("city") // - .and("smallestPop").first("pop"), // + .last("city").as("biggestCity") // + .last("pop").as("biggestPop") // + .first("city").as("smallestCity") // + .first("pop").as("smallestPop"), // project() // // .and(previousOperation()).exclude() // .and("state").previousOperation() // @@ -361,7 +361,7 @@ public class AggregationTests { TypedAggregation agg = newAggregation(ZipInfo.class, // group("state") // - .and("totalPop").sum("population"), // + .sum("population").as("totalPop"), // sort(ASC, previousOperation(), "totalPop"), // match(where("totalPop").gte(10 * 1000 * 1000)) // ); @@ -401,7 +401,7 @@ public class AggregationTests { TypedAggregation agg = newAggregation(UserWithLikes.class, // unwind("likes"), // - group("likes").and("number").count(), // + group("likes").count().as("number"), // sort(DESC, "number"), // limit(5), // sort(ASC, previousOperation()) // diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GroupOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GroupOperationUnitTests.java index 5bed23d1e..bbd31ea98 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GroupOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GroupOperationUnitTests.java @@ -34,7 +34,7 @@ public class GroupOperationUnitTests { @Test(expected = IllegalArgumentException.class) public void rejectsNullFields() { - new GroupOperation(null); + new GroupOperation((Fields) null); } @Test @@ -42,8 +42,7 @@ public class GroupOperationUnitTests { GroupOperation operation = new GroupOperation(fields("a")); - DBObject dbObject = operation.toDBObject(Aggregation.DEFAULT_CONTEXT); - DBObject groupClause = DBObjectUtils.getAsDBObject(dbObject, "$group"); + DBObject groupClause = extractDbObjectFromGroupOperation(operation); assertThat(groupClause.get(UNDERSCORE_ID), is((Object) "$a")); } @@ -53,8 +52,7 @@ public class GroupOperationUnitTests { GroupOperation operation = new GroupOperation(fields("a").and("b", "c")); - DBObject dbObject = operation.toDBObject(Aggregation.DEFAULT_CONTEXT); - DBObject groupClause = DBObjectUtils.getAsDBObject(dbObject, "$group"); + DBObject groupClause = extractDbObjectFromGroupOperation(operation); DBObject idClause = DBObjectUtils.getAsDBObject(groupClause, UNDERSCORE_ID); assertThat(idClause.get("a"), is((Object) "$a")); @@ -75,13 +73,98 @@ public class GroupOperationUnitTests { @Test public void groupFactoryMethodWithMultipleFieldsAndSumOperation() { - Fields fields = fields("a", "b").and("c"); // .and("d", 42); - GroupOperation groupOperation = new GroupOperation(fields).and("e").sum(); - - DBObject dbObject = groupOperation.toDBObject(Aggregation.DEFAULT_CONTEXT); + GroupOperation groupOperation = Aggregation.group(fields("a", "b").and("c")) // + .sum("e").as("e"); - DBObject groupClause = DBObjectUtils.getAsDBObject(dbObject, "$group"); + DBObject groupClause = extractDbObjectFromGroupOperation(groupOperation); DBObject eOp = DBObjectUtils.getAsDBObject(groupClause, "e"); assertThat(eOp, is((DBObject) new BasicDBObject("$sum", "$e"))); } + + @Test + public void groupFactoryMethodWithMultipleFieldsAndSumOperationWithAlias() { + + GroupOperation groupOperation = Aggregation.group(fields("a", "b").and("c")) // + .sum("e").as("ee"); + + DBObject groupClause = extractDbObjectFromGroupOperation(groupOperation); + DBObject eOp = DBObjectUtils.getAsDBObject(groupClause, "ee"); + assertThat(eOp, is((DBObject) new BasicDBObject("$sum", "$e"))); + } + + @Test + public void groupFactoryMethodWithMultipleFieldsAndCountOperationWithout() { + + GroupOperation groupOperation = Aggregation.group(fields("a", "b").and("c")) // + .count().as("count"); + + DBObject groupClause = extractDbObjectFromGroupOperation(groupOperation); + DBObject eOp = DBObjectUtils.getAsDBObject(groupClause, "count"); + assertThat(eOp, is((DBObject) new BasicDBObject("$sum", 1))); + } + + @Test + public void groupFactoryMethodWithMultipleFieldsAndMultipleAggregateOperationsWithAlias() { + + GroupOperation groupOperation = Aggregation.group(fields("a", "b").and("c")) // + .sum("e").as("sum") // + .min("e").as("min"); // + + DBObject groupClause = extractDbObjectFromGroupOperation(groupOperation); + DBObject sum = DBObjectUtils.getAsDBObject(groupClause, "sum"); + assertThat(sum, is((DBObject) new BasicDBObject("$sum", "$e"))); + + DBObject min = DBObjectUtils.getAsDBObject(groupClause, "min"); + assertThat(min, is((DBObject) new BasicDBObject("$min", "$e"))); + } + + @Test + public void groupOperationPushWithValue() { + + GroupOperation groupOperation = Aggregation.group("a", "b").push(1).as("x"); + + DBObject groupClause = extractDbObjectFromGroupOperation(groupOperation); + DBObject push = DBObjectUtils.getAsDBObject(groupClause, "x"); + + assertThat(push, is((DBObject) new BasicDBObject("$push", 1))); + } + + @Test + public void groupOperationPushWithReference() { + + GroupOperation groupOperation = Aggregation.group("a", "b").push("ref").as("x"); + + DBObject groupClause = extractDbObjectFromGroupOperation(groupOperation); + DBObject push = DBObjectUtils.getAsDBObject(groupClause, "x"); + + assertThat(push, is((DBObject) new BasicDBObject("$push", "$ref"))); + } + + @Test + public void groupOperationAddToSetWithReference() { + + GroupOperation groupOperation = Aggregation.group("a", "b").addToSet("ref").as("x"); + + DBObject groupClause = extractDbObjectFromGroupOperation(groupOperation); + DBObject push = DBObjectUtils.getAsDBObject(groupClause, "x"); + + assertThat(push, is((DBObject) new BasicDBObject("$addToSet", "$ref"))); + } + + @Test + public void groupOperationAddToSetWithValue() { + + GroupOperation groupOperation = Aggregation.group("a", "b").addToSet(42).as("x"); + + DBObject groupClause = extractDbObjectFromGroupOperation(groupOperation); + DBObject push = DBObjectUtils.getAsDBObject(groupClause, "x"); + + assertThat(push, is((DBObject) new BasicDBObject("$addToSet", 42))); + } + + private DBObject extractDbObjectFromGroupOperation(GroupOperation groupOperation) { + DBObject dbObject = groupOperation.toDBObject(Aggregation.DEFAULT_CONTEXT); + DBObject groupClause = DBObjectUtils.getAsDBObject(dbObject, "$group"); + return groupClause; + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java index 74d473894..b91c5f685 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java @@ -18,8 +18,11 @@ package org.springframework.data.mongodb.core.aggregation; import static org.hamcrest.Matchers.*; import static org.junit.Assert.*; +import java.util.Arrays; + import org.junit.Test; import org.springframework.data.mongodb.core.DBObjectUtils; +import org.springframework.data.mongodb.core.aggregation.ProjectionOperation.ProjectionOperationBuilder; import com.mongodb.BasicDBList; import com.mongodb.DBObject; @@ -31,6 +34,11 @@ import com.mongodb.DBObject; */ public class ProjectionOperationUnitTests { + static final String MOD = "$mod"; + static final String ADD = "$add"; + static final String SUBTRACT = "$subtract"; + static final String MULTIPLY = "$multiply"; + static final String DIVIDE = "$divide"; static final String PROJECT = "$project"; @Test(expected = IllegalArgumentException.class) @@ -86,4 +94,102 @@ public class ProjectionOperationUnitTests { assertThat(addClause.get(0), is((Object) "$foo")); assertThat(addClause.get(1), is((Object) 41)); } + + public void arithmenticProjectionOperationWithoutAlias() { + + String fieldName = "a"; + ProjectionOperationBuilder operation = new ProjectionOperation().and(fieldName).plus(1); + DBObject dbObject = operation.toDBObject(Aggregation.DEFAULT_CONTEXT); + DBObject projectClause = DBObjectUtils.getAsDBObject(dbObject, PROJECT); + DBObject oper = exctractOperation(fieldName, projectClause); + + assertThat(oper.containsField(ADD), is(true)); + assertThat(oper.get(ADD), is((Object) Arrays. asList("$a", 1))); + } + + @Test + public void arithmenticProjectionOperationPlus() { + + String fieldName = "a"; + String fieldAlias = "b"; + ProjectionOperation operation = new ProjectionOperation().and(fieldName).plus(1).as(fieldAlias); + DBObject dbObject = operation.toDBObject(Aggregation.DEFAULT_CONTEXT); + DBObject projectClause = DBObjectUtils.getAsDBObject(dbObject, PROJECT); + + DBObject oper = exctractOperation(fieldAlias, projectClause); + assertThat(oper.containsField(ADD), is(true)); + assertThat(oper.get(ADD), is((Object) Arrays. asList("$a", 1))); + } + + @Test + public void arithmenticProjectionOperationMinus() { + + String fieldName = "a"; + String fieldAlias = "b"; + ProjectionOperation operation = new ProjectionOperation().and(fieldName).minus(1).as(fieldAlias); + DBObject dbObject = operation.toDBObject(Aggregation.DEFAULT_CONTEXT); + DBObject projectClause = DBObjectUtils.getAsDBObject(dbObject, PROJECT); + DBObject oper = exctractOperation(fieldAlias, projectClause); + + assertThat(oper.containsField(SUBTRACT), is(true)); + assertThat(oper.get(SUBTRACT), is((Object) Arrays. asList("$a", 1))); + } + + @Test + public void arithmenticProjectionOperationMultiply() { + + String fieldName = "a"; + String fieldAlias = "b"; + ProjectionOperation operation = new ProjectionOperation().and(fieldName).multiply(1).as(fieldAlias); + DBObject dbObject = operation.toDBObject(Aggregation.DEFAULT_CONTEXT); + DBObject projectClause = DBObjectUtils.getAsDBObject(dbObject, PROJECT); + DBObject oper = exctractOperation(fieldAlias, projectClause); + + assertThat(oper.containsField(MULTIPLY), is(true)); + assertThat(oper.get(MULTIPLY), is((Object) Arrays. asList("$a", 1))); + } + + @Test + public void arithmenticProjectionOperationDivide() { + + String fieldName = "a"; + String fieldAlias = "b"; + ProjectionOperation operation = new ProjectionOperation().and(fieldName).divide(1).as(fieldAlias); + DBObject dbObject = operation.toDBObject(Aggregation.DEFAULT_CONTEXT); + DBObject projectClause = DBObjectUtils.getAsDBObject(dbObject, PROJECT); + DBObject oper = exctractOperation(fieldAlias, projectClause); + + assertThat(oper.containsField(DIVIDE), is(true)); + assertThat(oper.get(DIVIDE), is((Object) Arrays. asList("$a", 1))); + } + + @Test(expected = IllegalArgumentException.class) + public void arithmenticProjectionOperationDivideByZeroException() { + + new ProjectionOperation().and("a").divide(0); + } + + @Test + public void arithmenticProjectionOperationMod() { + + String fieldName = "a"; + String fieldAlias = "b"; + ProjectionOperation operation = new ProjectionOperation().and(fieldName).mod(3).as(fieldAlias); + DBObject dbObject = operation.toDBObject(Aggregation.DEFAULT_CONTEXT); + DBObject projectClause = DBObjectUtils.getAsDBObject(dbObject, PROJECT); + DBObject oper = exctractOperation(fieldAlias, projectClause); + + assertThat(oper.containsField(MOD), is(true)); + assertThat(oper.get(MOD), is((Object) Arrays. asList("$a", 3))); + } + + @Test(expected = IllegalArgumentException.class) + public void arithmenticProjectionOperationModByZeroException() { + + new ProjectionOperation().and("a").mod(0); + } + + private static DBObject exctractOperation(String field, DBObject fromProjectClause) { + return (DBObject) fromProjectClause.get(field); + } }