From feafd50b59ddb8a85996444bbccc2dca60211057 Mon Sep 17 00:00:00 2001 From: Thomas Darimont Date: Tue, 8 Oct 2013 15:24:28 +0200 Subject: [PATCH] DATAMONGO-769 - Improve support for arithmetic operators in AggregationFramework. We now support the usage of field references in arithmetic projection operations. Original pull request: #80. --- .../core/aggregation/ProjectionOperation.java | 68 ++++++++++++++++++- .../core/aggregation/AggregationTests.java | 10 +++ .../ProjectionOperationUnitTests.java | 38 +++++++++++ 3 files changed, 115 insertions(+), 1 deletion(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java index bd26e582f..16ed45afd 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java @@ -192,6 +192,7 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation { * Builder for {@link ProjectionOperation}s on a field. * * @author Oliver Gierke + * @author Thomas Darimont */ public static class ProjectionOperationBuilder implements AggregationOperation { @@ -266,6 +267,18 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation { return project("add", number); } + /** + * Generates an {@code $add} expression that adds the value of the given field to the previously mentioned field. + * + * @param fieldReference + * @return + */ + public ProjectionOperationBuilder plus(String fieldReference) { + + Assert.notNull(fieldReference, "Field reference must not be null!"); + return project("add", Fields.field(fieldReference)); + } + /** * Generates an {@code $subtract} expression that subtracts the given number to the previously mentioned field. * @@ -278,6 +291,19 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation { return project("subtract", number); } + /** + * Generates an {@code $subtract} expression that subtracts the value of the given field to the previously mentioned + * field. + * + * @param fieldReference + * @return + */ + public ProjectionOperationBuilder minus(String fieldReference) { + + Assert.notNull(fieldReference, "Field reference must not be null!"); + return project("subtract", Fields.field(fieldReference)); + } + /** * Generates an {@code $multiply} expression that multiplies the given number with the previously mentioned field. * @@ -290,6 +316,19 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation { return project("multiply", number); } + /** + * Generates an {@code $multiply} expression that multiplies the value of the given field with the previously + * mentioned field. + * + * @param fieldReference + * @return + */ + public ProjectionOperationBuilder multiply(String fieldReference) { + + Assert.notNull(fieldReference, "Field reference must not be null!"); + return project("multiply", Fields.field(fieldReference)); + } + /** * Generates an {@code $divide} expression that divides the previously mentioned field by the given number. * @@ -303,6 +342,19 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation { return project("divide", number); } + /** + * Generates an {@code $divide} expression that divides the value of the given field by the previously mentioned + * field. + * + * @param fieldReference + * @return + */ + public ProjectionOperationBuilder divide(String fieldReference) { + + Assert.notNull(fieldReference, "Field reference must not be null!"); + return project("divide", Fields.field(fieldReference)); + } + /** * Generates an {@code $mod} expression that divides the previously mentioned field by the given number and returns * the remainder. @@ -317,7 +369,21 @@ public class ProjectionOperation implements FieldsExposingAggregationOperation { return project("mod", number); } - /* (non-Javadoc) + /** + * Generates an {@code $mod} expression that divides the value of the given field by the previously mentioned field + * and returns the remainder. + * + * @param fieldReference + * @return + */ + public ProjectionOperationBuilder mod(String fieldReference) { + + Assert.notNull(fieldReference, "Field reference must not be null!"); + return project("mod", Fields.field(fieldReference)); + } + + /* + * (non-Javadoc) * @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDBObject(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext) */ @Override diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index bd632c3ae..b7534f5f6 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -440,6 +440,11 @@ public class AggregationTests { .and("netPrice").multiply(2).as("netPriceMul2") // .and("netPrice").divide(1.19).as("netPriceDiv119") // .and("spaceUnits").mod(2).as("spaceUnitsMod2") // + .and("spaceUnits").plus("spaceUnits").as("spaceUnitsPlusSpaceUnits") // + .and("spaceUnits").minus("spaceUnits").as("spaceUnitsMinusSpaceUnits") // + .and("spaceUnits").multiply("spaceUnits").as("spaceUnitsMultiplySpaceUnits") // + .and("spaceUnits").divide("spaceUnits").as("spaceUnitsDivideSpaceUnits") // + .and("spaceUnits").mod("spaceUnits").as("spaceUnitsModSpaceUnits") // ); AggregationResults result = mongoTemplate.aggregate(agg, DBObject.class); @@ -453,6 +458,11 @@ public class AggregationTests { assertThat((Double) resultList.get(0).get("netPriceMul2"), is(netPrice * 2)); assertThat((Double) resultList.get(0).get("netPriceDiv119"), is(netPrice / 1.19)); assertThat((Integer) resultList.get(0).get("spaceUnitsMod2"), is(spaceUnits % 2)); + assertThat((Integer) resultList.get(0).get("spaceUnitsPlusSpaceUnits"), is(spaceUnits + spaceUnits)); + assertThat((Integer) resultList.get(0).get("spaceUnitsMinusSpaceUnits"), is(spaceUnits - spaceUnits)); + assertThat((Integer) resultList.get(0).get("spaceUnitsMultiplySpaceUnits"), is(spaceUnits * spaceUnits)); + assertThat((Double) resultList.get(0).get("spaceUnitsDivideSpaceUnits"), is((double) (spaceUnits / spaceUnits))); + assertThat((Integer) resultList.get(0).get("spaceUnitsModSpaceUnits"), is(spaceUnits % spaceUnits)); } /** diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java index 1ed084fb0..44944c99e 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java @@ -25,6 +25,7 @@ import org.springframework.data.mongodb.core.DBObjectUtils; import org.springframework.data.mongodb.core.aggregation.ProjectionOperation.ProjectionOperationBuilder; import com.mongodb.BasicDBList; +import com.mongodb.BasicDBObject; import com.mongodb.DBObject; /** @@ -230,6 +231,43 @@ public class ProjectionOperationUnitTests { new ProjectionOperation().and("a").mod(0); } + /** + * @see DATAMONGO-769 + */ + @Test + public void allowArithmeticOperationsWithFieldReferences() { + + ProjectionOperation operation = Aggregation.project() // + .and("foo").plus("bar").as("fooPlusBar") // + .and("foo").minus("bar").as("fooMinusBar") // + .and("foo").multiply("bar").as("fooMultiplyBar") // + .and("foo").divide("bar").as("fooDivideBar") // + .and("foo").mod("bar").as("fooModBar"); + + DBObject dbObject = operation.toDBObject(Aggregation.DEFAULT_CONTEXT); + DBObject projectClause = DBObjectUtils.getAsDBObject(dbObject, PROJECT); + + assertThat((BasicDBObject) projectClause.get("fooPlusBar"), // + is(new BasicDBObject("$add", dbList("$foo", "$bar")))); + assertThat((BasicDBObject) projectClause.get("fooMinusBar"), // + is(new BasicDBObject("$subtract", dbList("$foo", "$bar")))); + assertThat((BasicDBObject) projectClause.get("fooMultiplyBar"), // + is(new BasicDBObject("$multiply", dbList("$foo", "$bar")))); + assertThat((BasicDBObject) projectClause.get("fooDivideBar"), // + is(new BasicDBObject("$divide", dbList("$foo", "$bar")))); + assertThat((BasicDBObject) projectClause.get("fooModBar"), // + is(new BasicDBObject("$mod", dbList("$foo", "$bar")))); + } + + public static BasicDBList dbList(Object... items) { + + BasicDBList list = new BasicDBList(); + for (Object item : items) { + list.add(item); + } + return list; + } + private static DBObject exctractOperation(String field, DBObject fromProjectClause) { return (DBObject) fromProjectClause.get(field); }