From c3f9af01e662d14fdba97c3628b22cde30e4b5e3 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 2 Dec 2016 15:54:30 +0100 Subject: [PATCH] DATAMONGO-1540 - Add support for $map (aggregation). We now support $map operator in aggregation. Original pull request: #420. --- .../aggregation/AggregationExpressions.java | 188 +++++++++++++++++- .../AggregationFunctionExpressions.java | 2 +- .../ProjectionOperationUnitTests.java | 33 +++ 3 files changed, 214 insertions(+), 9 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java index e49008403..f62bee9a9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationExpressions.java @@ -20,10 +20,12 @@ import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; +import com.mongodb.BasicDBObject; +import com.mongodb.DBObject; import org.bson.Document; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Filter.AsBuilder; +import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.Map.ArrayOfBuilder; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; import org.springframework.util.Assert; @@ -1779,6 +1781,24 @@ public interface AggregationExpressions { } } + /** + * Gateway to {@literal Date} aggregation operations. + * + * @author Christoph Strobl + */ + class VariableOperators { + + /** + * Starts building new {@link Map} that applies an {@link AggregationExpression} to each item of a referenced array + * and returns an array with the applied results. + * + * @return + */ + public static ArrayOfBuilder map() { + return Map.map(); + } + } + /** * @author Christoph Strobl */ @@ -1807,10 +1827,10 @@ public interface AggregationExpressions { args.add(unpack(val, context)); } valueToUse = args; - } else if (value instanceof Map) { + } else if (value instanceof java.util.Map) { Document dbo = new Document(); - for (Map.Entry entry : ((Map) value).entrySet()) { + for (java.util.Map.Entry entry : ((java.util.Map) value).entrySet()) { dbo.put(entry.getKey(), unpack(entry.getValue(), context)); } valueToUse = dbo; @@ -1864,10 +1884,10 @@ public interface AggregationExpressions { protected Object append(String key, Object value) { - if (!(value instanceof Map)) { + if (!(value instanceof java.util.Map)) { throw new IllegalArgumentException("o_O"); } - Map clone = new LinkedHashMap((Map) value); + java.util.Map clone = new LinkedHashMap((java.util.Map) value); clone.put(key, value); return clone; @@ -2342,6 +2362,7 @@ public interface AggregationExpressions { Assert.notNull(expression, "Expression must not be null!"); return new Abs(expression); } + /** * Creates new {@link Abs}. * @@ -2493,7 +2514,6 @@ public interface AggregationExpressions { return "$divide"; } - /** * Creates new {@link Divide}. * @@ -4507,9 +4527,9 @@ public interface AggregationExpressions { }; } - private static Map argumentMap(Object date, String format) { + private static java.util.Map argumentMap(Object date, String format) { - Map args = new LinkedHashMap(2); + java.util.Map args = new LinkedHashMap(2); args.put("format", format); args.put("date", date); return args; @@ -5713,4 +5733,156 @@ public interface AggregationExpressions { } } + /** + * {@link AggregationExpression} for {@code $map}. + */ + class Map implements AggregationExpression { + + private Object sourceArray; + private String itemVariableName; + private AggregationExpression functionToApply; + + private Map(Object sourceArray, String itemVariableName, AggregationExpression functionToApply) { + + Assert.notNull(sourceArray, "SourceArray must not be null!"); + Assert.notNull(itemVariableName, "ItemVariableName must not be null!"); + Assert.notNull(functionToApply, "FunctionToApply must not be null!"); + + this.sourceArray = sourceArray; + this.itemVariableName = itemVariableName; + this.functionToApply = functionToApply; + } + + /** + * Starts building new {@link Map} that applies an {@link AggregationExpression} to each item of a referenced array + * and returns an array with the applied results. + * + * @return + */ + static ArrayOfBuilder map() { + + return new ArrayOfBuilder() { + + @Override + public AsBuilder itemsOf(final String fieldReference) { + + return new AsBuilder() { + + @Override + public FunctionBuilder as(final String variableName) { + + return new FunctionBuilder() { + + @Override + public Map andApply(final AggregationExpression expression) { + return new Map(Fields.field(fieldReference), variableName, expression); + } + }; + } + }; + } + + @Override + public AsBuilder itemsOf(final AggregationExpression source) { + + return new AsBuilder() { + + @Override + public FunctionBuilder as(final String variableName) { + + return new FunctionBuilder() { + + @Override + public Map andApply(final AggregationExpression expression) { + return new Map(source, variableName, expression); + } + }; + } + }; + } + }; + }; + + @Override + public Document toDocument(final AggregationOperationContext context) { + + return toMap(new ExposedFieldsAggregationOperationContext( + ExposedFields.synthetic(Fields.fields(itemVariableName)), context) { + + @Override + public FieldReference getReference(Field field) { + + FieldReference ref = null; + try { + ref = context.getReference(field); + } catch (Exception e) { + // just ignore that one. + } + return ref != null ? ref : super.getReference(field); + } + }); + } + + private Document toMap(AggregationOperationContext context) { + + Document map = new Document(); + + Document input; + if (sourceArray instanceof Field) { + input = new Document("input", context.getReference((Field) sourceArray).toString()); + } else { + input = new Document("input", ((AggregationExpression) sourceArray).toDocument(context)); + } + + map.putAll(context.getMappedObject(input)); + map.put("as", itemVariableName); + map.put("in", functionToApply.toDocument(new NestedDelegatingExpressionAggregationOperationContext(context))); + + return new Document("$map", map); + } + + interface ArrayOfBuilder { + + /** + * Set the field that resolves to an array on which to apply the {@link AggregationExpression}. + * + * @param fieldReference must not be {@literal null}. + * @return + */ + AsBuilder itemsOf(String fieldReference); + + /** + * Set the {@link AggregationExpression} that results in an array on which to apply the + * {@link AggregationExpression}. + * + * @param expression must not be {@literal null}. + * @return + */ + AsBuilder itemsOf(AggregationExpression expression); + } + + interface AsBuilder { + + /** + * Define the {@literal variableName} for addressing items within the array. + * + * @param variableName must not be {@literal null}. + * @return + */ + FunctionBuilder as(String variableName); + } + + interface FunctionBuilder { + + /** + * Creates new {@link Map} that applies the given {@link AggregationExpression} to each item of the referenced + * array and returns an array with the applied results. + * + * @param expression must not be {@literal null}. + * @return + */ + Map andApply(AggregationExpression expression); + } + } + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java index 8cceaaf0d..2804ca0bf 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationFunctionExpressions.java @@ -35,7 +35,7 @@ import org.springframework.util.Assert; @Deprecated public enum AggregationFunctionExpressions { - SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT; + SIZE, CMP, EQ, GT, GTE, LT, LTE, NE, SUBTRACT, ADD; /** * Returns an {@link AggregationExpression} build from the current {@link Enum} name and the given parameters. diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java index 0604188a4..5af53f3ce 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java @@ -25,6 +25,8 @@ import static org.springframework.data.mongodb.test.util.IsBsonObject.*; import java.util.Arrays; import java.util.List; +import com.mongodb.DBObject; +import com.mongodb.util.JSON; import org.bson.Document; import org.junit.Test; import org.springframework.data.mongodb.core.DocumentTestUtils; @@ -36,6 +38,7 @@ import org.springframework.data.mongodb.core.aggregation.AggregationExpressions. import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.LiteralOperators; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.SetOperators; import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.StringOperators; +import org.springframework.data.mongodb.core.aggregation.AggregationExpressions.VariableOperators; import org.springframework.data.mongodb.core.aggregation.ProjectionOperation.ProjectionOperationBuilder; /** @@ -1670,6 +1673,36 @@ public class ProjectionOperationUnitTests { assertThat(agg, is(Document.parse("{ $project: { result: { $not: [ { $gt: [ \"$qty\", 250 ] } ] } } }"))); } + /** + * @see DATAMONGO-784 + */ + @Test + public void shouldRenderMapAggregationExpression() { + + Document agg = Aggregation.project() + .and(VariableOperators.map().itemsOf("quizzes").as("grade") + .andApply(AggregationFunctionExpressions.ADD.of(field("grade"), 2))) + .as("adjustedGrades").toDocument(Aggregation.DEFAULT_CONTEXT); + + assertThat(agg, is(Document.parse( + "{ $project:{ adjustedGrades:{ $map: { input: \"$quizzes\", as: \"grade\",in: { $add: [ \"$$grade\", 2 ] }}}}}"))); + } + + /** + * @see DATAMONGO-784 + */ + @Test + public void shouldRenderMapAggregationExpressionOnExpression() { + + Document agg = Aggregation.project() + .and(VariableOperators.map().itemsOf(AggregationFunctionExpressions.SIZE.of("foo")).as("grade") + .andApply(AggregationFunctionExpressions.ADD.of(field("grade"), 2))) + .as("adjustedGrades").toDocument(Aggregation.DEFAULT_CONTEXT); + + assertThat(agg, is(Document.parse( + "{ $project:{ adjustedGrades:{ $map: { input: { $size : [\"foo\"]}, as: \"grade\",in: { $add: [ \"$$grade\", 2 ] }}}}}"))); + } + private static Document exctractOperation(String field, Document fromProjectClause) { return (Document) fromProjectClause.get(field); }