From 0ef852a8fcdb4654d8242b568f12e52f92575ada Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 25 Sep 2020 13:31:16 +0200 Subject: [PATCH] DATAMONGO-2623 - Add support for $function and $accumulator aggregation operators. Original pull request: #887. --- .../AbstractAggregationExpression.java | 42 ++ .../core/aggregation/ScriptOperators.java | 548 ++++++++++++++++++ .../aggregation/ScriptOperatorsUnitTests.java | 94 +++ src/main/asciidoc/new-features.adoc | 1 + src/main/asciidoc/reference/mongodb.adoc | 4 + 5 files changed, 689 insertions(+) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ScriptOperators.java create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ScriptOperatorsUnitTests.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java index ad607cbca..82c03758b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java @@ -140,6 +140,48 @@ abstract class AbstractAggregationExpression implements AggregationExpression { } + protected java.util.Map remove(String key) { + + Assert.isInstanceOf(Map.class, this.value, "Value must be a type of Map!"); + + java.util.Map clone = new LinkedHashMap<>((java.util.Map) this.value); + clone.remove(key); + return clone; + } + + /** + * Append the given key at the position in the underlying {@link LinkedHashMap}. 
+ * + * @param index + * @param key + * @param value + * @return + * @since 3.1 + */ + protected java.util.Map appendAt(int index, String key, Object value) { + + Assert.isInstanceOf(Map.class, this.value, "Value must be a type of Map!"); + + java.util.LinkedHashMap clone = new java.util.LinkedHashMap<>(); + + int i = 0; + for (Map.Entry entry : ((java.util.Map) this.value).entrySet()) { + + if (i == index) { + clone.put(key, value); + } + if (!entry.getKey().equals(key)) { + clone.put(entry.getKey(), entry.getValue()); + } + i++; + } + if (i <= index) { + clone.put(key, value); + } + return clone; + + } + protected List values() { if (value instanceof List) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ScriptOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ScriptOperators.java new file mode 100644 index 000000000..6d451aca0 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ScriptOperators.java @@ -0,0 +1,548 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.core.aggregation; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.springframework.data.mongodb.core.aggregation.ScriptOperators.Accumulator.AccumulatorBuilder; +import org.springframework.data.mongodb.core.aggregation.ScriptOperators.Accumulator.AccumulatorInitBuilder; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +/** + * Gateway to {@literal $function} and {@literal $accumulator} aggregation operations. + *

+ * Using {@link ScriptOperators} as part of the {@link Aggregation} requires MongoDB server to have + * server-side JavaScript execution + * enabled. + * + * @author Christoph Strobl + * @since 3.1 + */ +public class ScriptOperators { + + /** + * Create a custom aggregation + * $function in JavaScript. + * + * @param body The function definition. Must not be {@literal null}. + * @return new instance of {@link Function}. + */ + public static Function function(String body) { + return Function.function(body); + } + + /** + * Create a custom $accumulator + * operator in Javascript. + * + * @return new instance of {@link AccumulatorInitBuilder}. + */ + public static AccumulatorInitBuilder accumulatorBuilder() { + return new AccumulatorBuilder(); + } + + /** + * {@link Function} defines a custom aggregation + * $function in JavaScript. + *

+ * + * { + * $function: { + * body: ..., + * args: ..., + * lang: "js" + * } + * } + * + *

+ * {@link Function} cannot be used as part of {@link org.springframework.data.mongodb.core.schema.MongoJsonSchema + * schema} validation query expression.
+ * NOTE: Server-Side JavaScript + * execution must be + * enabled + * + * @see MongoDB Documentation: + * $function + * @since 3.1 + */ + public static class Function extends AbstractAggregationExpression { + + private Function(Map values) { + super(values); + } + + /** + * Create a new {@link Function} with the given function definition. + * + * @param body must not be {@literal null}. + * @return new instance of {@link Function}. + */ + public static Function function(String body) { + + Map function = new LinkedHashMap<>(2); + function.put(Fields.BODY.toString(), body); + function.put(Fields.ARGS.toString(), Collections.emptyList()); + function.put(Fields.LANG.toString(), "js"); + + return new Function(function); + } + + /** + * Set the arguments passed to the function body. + * + * @param args the arguments passed to the function body. Leave empty if the function does not take any arguments. + * @return new instance of {@link Function}. + */ + public Function args(Object... args) { + return args(Arrays.asList(args)); + } + + /** + * Set the arguments passed to the function body. + * + * @param args the arguments passed to the function body. Leave empty if the function does not take any arguments. + * @return new instance of {@link Function}. + */ + public Function args(List args) { + + Assert.notNull(args, "Args must not be null! Use an empty list instead."); + return new Function(appendAt(1, Fields.ARGS.toString(), args)); + } + + /** + * The language used in the body. + * + * @param lang must not be {@literal null} nor empty. + * @return new instance of {@link Function}. + */ + public Function lang(String lang) { + + Assert.hasText(lang, "Lang must not be null nor empty! 
The default would be 'js'."); + return new Function(appendAt(2, Fields.LANG.toString(), lang)); + } + + @Nullable + List getArgs() { + return get(Fields.ARGS.toString()); + } + + String getBody() { + return get(Fields.BODY.toString()); + } + + String getLang() { + return get(Fields.LANG.toString()); + } + + @Override + protected String getMongoMethod() { + return "$function"; + } + + enum Fields { + + BODY, ARGS, LANG; + + @Override + public String toString() { + return name().toLowerCase(java.util.Locale.ROOT); + } + } + } + + /** + * {@link Accumulator} defines a custom aggregation + * $accumulator operator, + * one that maintains its state (e.g. totals, maximums, minimums, and related data) as documents progress through the + * pipeline, in JavaScript. + *

+ * + * { + * $accumulator: { + * init: ..., + * initArgs: ..., + * accumulate: ..., + * accumulateArgs: ..., + * merge: ..., + * finalize: ..., + * lang: "js" + * } + * } + * + *

+ * {@link Accumulator} can be used as part of {@link GroupOperation $group}, {@link BucketOperation $bucket} and + * {@link BucketAutoOperation $bucketAuto} pipeline stages.
+ * NOTE: Server-Side JavaScript + * execution must be + * enabled + * + * @see MongoDB Documentation: + * $accumulator + * @since 3.1 + */ + public static class Accumulator extends AbstractAggregationExpression { + + private Accumulator(Map value) { + super(value); + } + + @Override + protected String getMongoMethod() { + return "$accumulator"; + } + + enum Fields { + + ACCUMULATE("accumulate"), // + ACCUMULATE_ARGS("accumulateArgs"), // + FINALIZE("finalize"), // + INIT("init"), // + INIT_ARGS("initArgs"), // + LANG("lang"), // + MERGE("merge"); // + + private String field; + + Fields(String field) { + this.field = field; + } + + @Override + public String toString() { + return field; + } + } + + public interface AccumulatorInitBuilder { + + /** + * Define the {@code init} {@link Function} for the {@link Accumulator accumulators} initial state. The function + * receives its arguments from the {@link Function#args(Object...) initArgs} array expression. + *

+ * + * function(initArg1, initArg2, ...) { + * ... + * return initialState + * } + * + * + * @param function must not be {@literal null}. + * @return this. + */ + default AccumulatorAccumulateBuilder init(Function function) { + return init(function.getBody()).initArgs(function.getArgs()); + } + + /** + * Define the {@code init} function for the {@link Accumulator accumulators} initial state. The function receives + * its arguments from the {@link AccumulatorInitArgsBuilder#initArgs(Object...)} array expression. + *

+ * + * function(initArg1, initArg2, ...) { + * ... + * return initialState + * } + * + * + * @param function must not be {@literal null}. + * @return this. + */ + AccumulatorInitArgsBuilder init(String function); + + /** + * The language used in the {@code $accumulator} code. + * + * @param lang must not be {@literal null}. Default is {@literal js}. + * @return this. + */ + AccumulatorInitBuilder lang(String lang); + } + + public interface AccumulatorInitArgsBuilder extends AccumulatorAccumulateBuilder { + + /** + * Define the optional {@code initArgs} for the {@link AccumulatorInitBuilder#init(String)} function. + * + * @param args must not be {@literal null}. + * @return this. + */ + default AccumulatorAccumulateBuilder initArgs(Object... args) { + return initArgs(Arrays.asList(args)); + } + + /** + * Define the optional {@code initArgs} for the {@link AccumulatorInitBuilder#init(String)} function. + * + * @param args can be {@literal null}. + * @return this. + */ + AccumulatorAccumulateBuilder initArgs(@Nullable List args); + } + + public interface AccumulatorAccumulateBuilder { + + /** + * Set the {@code accumulate} {@link Function} that updates the state for each document. The functions first + * argument is the current {@code state}, additional arguments can be defined via {@link Function#args(Object...) + * accumulateArgs}. + *

+ * + * function(state, accumArg1, accumArg2, ...) { + * ... + * return newState + * } + * + * + * @param function must not be {@literal null}. + * @return this. + */ + default AccumulatorMergeBuilder accumulate(Function function) { + return accumulate(function.getBody()).accumulateArgs(function.getArgs()); + } + + /** + * Set the {@code accumulate} function that updates the state for each document. The functions first argument is + * the current {@code state}, additional arguments can be defined via + * {@link AccumulatorAccumulateArgsBuilder#accumulateArgs(Object...)}. + *

+ * + * function(state, accumArg1, accumArg2, ...) { + * ... + * return newState + * } + * + * + * @param function must not be {@literal null}. + * @return this. + */ + AccumulatorAccumulateArgsBuilder accumulate(String function); + } + + public interface AccumulatorAccumulateArgsBuilder extends AccumulatorMergeBuilder { + + /** + * Define additional {@code accumulateArgs} for the {@link AccumulatorAccumulateBuilder#accumulate(String)} + * function. + * + * @param args must not be {@literal null}. + * @return this. + */ + default AccumulatorMergeBuilder accumulateArgs(Object... args) { + return accumulateArgs(Arrays.asList(args)); + } + + /** + * Define additional {@code accumulateArgs} for the {@link AccumulatorAccumulateBuilder#accumulate(String)} + * function. + * + * @param args can be {@literal null}. + * @return this. + */ + AccumulatorMergeBuilder accumulateArgs(@Nullable List args); + } + + public interface AccumulatorMergeBuilder { + + /** + * Set the {@code merge} function used to merge two internal states.
+ * This might be required because the operation is run on a sharded cluster or when the operator exceeds its + * memory limit. + *

+ * + * function(state1, state2) { + * ... + * return newState + * } + * + * + * @param function must not be {@literal null}. + * @return this. + */ + AccumulatorFinalizeBuilder merge(String function); + } + + public interface AccumulatorFinalizeBuilder { + + /** + * Set the {@code finalize} function used to update the result of the accumulation when all documents have been + * processed. + *

+ * + * function(state) { + * ... + * return finalState + * } + * + * + * @param function must not be {@literal null}. + * @return new instance of {@link Accumulator}. + */ + Accumulator finalize(String function); + } + + public static class AccumulatorBuilder + implements AccumulatorInitBuilder, AccumulatorInitArgsBuilder, AccumulatorAccumulateBuilder, + AccumulatorAccumulateArgsBuilder, AccumulatorMergeBuilder, AccumulatorFinalizeBuilder { + + private List initArgs; + private String initFunction; + private List accumulateArgs; + private String accumulateFunction; + private String mergeFunction; + private String finalizeFunction; + private String lang = "js"; + + /** + * Define the {@code init} function for the {@link Accumulator accumulators} initial state. The function receives + * its arguments from the {@link #initArgs(Object...)} array expression. + *

+ * + * function(initArg1, initArg2, ...) { + * ... + * return initialState + * } + * + * + * @param function must not be {@literal null}. + * @return this. + */ + public AccumulatorBuilder init(String function) { + + this.initFunction = function; + return this; + } + + /** + * Define the optional {@code initArgs} for the {@link #init(String)} function. + * + * @param args can be {@literal null}. + * @return this. + */ + public AccumulatorBuilder initArgs(@Nullable List args) { + + this.initArgs = args != null ? new ArrayList<>(args) : Collections.emptyList(); + return this; + } + + /** + * Set the {@code accumulate} function that updates the state for each document. The functions first argument is + * the current {@code state}, additional arguments can be defined via {@link #accumulateArgs(Object...)}. + *

+ * + * function(state, accumArg1, accumArg2, ...) { + * ... + * return newState + * } + * + * + * @param function must not be {@literal null}. + * @return this. + */ + public AccumulatorBuilder accumulate(String function) { + + this.accumulateFunction = function; + return this; + } + + /** + * Define additional {@code accumulateArgs} for the {@link #accumulate(String)} function. + * + * @param args can be {@literal null}. + * @return this. + */ + public AccumulatorBuilder accumulateArgs(@Nullable List args) { + + this.accumulateArgs = args != null ? new ArrayList<>(args) : Collections.emptyList(); + return this; + } + + /** + * Set the {@code merge} function used to merge two internal states.
+ * This might be required because the operation is run on a sharded cluster or when the operator exceeds its + * memory limit. + *

+ * + * function(state1, state2) { + * ... + * return newState + * } + * + * + * @param function must not be {@literal null}. + * @return this. + */ + public AccumulatorBuilder merge(String function) { + + this.mergeFunction = function; + return this; + } + + /** + * The language used in the {@code $accumulator} code. + * + * @param lang must not be {@literal null}. Default is {@literal js}. + * @return this. + */ + public AccumulatorBuilder lang(String lang) { + + this.lang = lang; + return this; + } + + /** + * Set the {@code finalize} function used to update the result of the accumulation when all documents have been + * processed. + *

+ * + * function(state) { + * ... + * return finalState + * } + * + * + * @param function must not be {@literal null}. + * @return new instance of {@link Accumulator}. + */ + public Accumulator finalize(String function) { + + this.finalizeFunction = function; + + Map args = new LinkedHashMap<>(); + args.put(Fields.INIT.toString(), initFunction); + if (!CollectionUtils.isEmpty(initArgs)) { + args.put(Fields.INIT_ARGS.toString(), initArgs); + } + args.put(Fields.ACCUMULATE.toString(), accumulateFunction); + if (!CollectionUtils.isEmpty(accumulateArgs)) { + args.put(Fields.ACCUMULATE_ARGS.toString(), accumulateArgs); + } + args.put(Fields.MERGE.toString(), mergeFunction); + args.put(Fields.FINALIZE.toString(), finalizeFunction); + args.put(Fields.LANG.toString(), lang); + + return new Accumulator(args); + } + + } + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ScriptOperatorsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ScriptOperatorsUnitTests.java new file mode 100644 index 000000000..fb237b631 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ScriptOperatorsUnitTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.core.aggregation; + +import static org.assertj.core.api.Assertions.*; +import static org.springframework.data.mongodb.core.aggregation.ScriptOperators.*; + +import java.util.Collections; + +import org.bson.Document; +import org.junit.jupiter.api.Test; + +/** + * @author Christoph Strobl + */ +class ScriptOperatorsUnitTests { + + private static final String FUNCTION_BODY = "function(name) { return hex_md5(name) == \"15b0a220baa16331e8d80e15367677ad\" }"; + private static final Document EMPTY_ARGS_FUNCTION_DOCUMENT = new Document("body", FUNCTION_BODY) + .append("args", Collections.emptyList()).append("lang", "js"); + + @Test // DATAMONGO-2623 + void functionWithoutArgsShouldBeRenderedCorrectly() { + + assertThat(function(FUNCTION_BODY).toDocument(Aggregation.DEFAULT_CONTEXT)) + .isEqualTo($function(EMPTY_ARGS_FUNCTION_DOCUMENT)); + } + + @Test // DATAMONGO-2623 + void functionWithArgsShouldBeRenderedCorrectly() { + + assertThat(function(FUNCTION_BODY).args("$name").toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo( + $function(new Document(EMPTY_ARGS_FUNCTION_DOCUMENT).append("args", Collections.singletonList("$name")))); + } + + private static final String INIT_FUNCTION = "function() { return { count: 0, sum: 0 } }"; + private static final String ACC_FUNCTION = "function(state, numCopies) { return { count: state.count + 1, sum: state.sum + numCopies } }"; + private static final String MERGE_FUNCTION = "function(state1, state2) { return { count: state1.count + state2.count, sum: state1.sum + state2.sum } }"; + private static final String FINALIZE_FUNCTION = "function(state) { return (state.sum / state.count) }"; + + private static final Document $ACCUMULATOR = Document.parse("{" + // + " $accumulator:" + // + " {" + // + " init: '" + INIT_FUNCTION + "'," + // + " accumulate: '" + ACC_FUNCTION + "'," + // + " accumulateArgs: [\"$copies\"]," + // + " merge: '" + MERGE_FUNCTION + "'," + // + " finalize: '" + 
FINALIZE_FUNCTION + "'," + // + " lang: \"js\"" + // + " }" + // + " }" + // + " }"); + + @Test // DATAMONGO-2623 + void accumulatorWithStringInput() { + + Accumulator accumulator = accumulatorBuilder() // + .init(INIT_FUNCTION) // + .accumulate(ACC_FUNCTION).accumulateArgs("$copies") // + .merge(MERGE_FUNCTION) // + .finalize(FINALIZE_FUNCTION); + + assertThat(accumulator.toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo($ACCUMULATOR); + } + + @Test // DATAMONGO-2623 + void accumulatorWithFunctionInput() { + + Accumulator accumulator = accumulatorBuilder() // + .init(function(INIT_FUNCTION)) // + .accumulate(function(ACC_FUNCTION).args("$copies")) // + .merge(MERGE_FUNCTION) // + .finalize(FINALIZE_FUNCTION); + + assertThat(accumulator.toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo($ACCUMULATOR); + } + + static Document $function(Document source) { + return new Document("$function", source); + } +} diff --git a/src/main/asciidoc/new-features.adoc b/src/main/asciidoc/new-features.adoc index 071d56f32..305e349cd 100644 --- a/src/main/asciidoc/new-features.adoc +++ b/src/main/asciidoc/new-features.adoc @@ -8,6 +8,7 @@ * Reactive SpEL support in `@Query` and `@Aggregation` query methods. * Aggregation hints via `AggregationOptions.builder().hint(bson).build()`. * Extension Function `KProperty.asPath()` to render property references into a property path representation. +* Server-side JavaScript aggregation expressions `$function` and `$accumulator` via `ScriptOperators`. 
[[new-features.3.0]] == What's New in Spring Data MongoDB 3.0 diff --git a/src/main/asciidoc/reference/mongodb.adoc b/src/main/asciidoc/reference/mongodb.adoc index 75ad097ce..7c6395fc5 100644 --- a/src/main/asciidoc/reference/mongodb.adoc +++ b/src/main/asciidoc/reference/mongodb.adoc @@ -2559,6 +2559,7 @@ The MongoDB Aggregation Framework provides the following types of aggregation op * Lookup Aggregation Operators * Convert Aggregation Operators * Object Aggregation Operators +* Script Aggregation Operators At the time of this writing, we provide support for the following Aggregation Operations in Spring Data MongoDB: @@ -2606,6 +2607,9 @@ At the time of this writing, we provide support for the following Aggregation Op | Object Aggregation Operators | `objectToArray`, `mergeObjects` + +| Script Aggregation Operators +| `function`, `accumulator` |=== * The operation is mapped or added by Spring Data MongoDB.