From 32bd540f9163f8ff520ef12f9133365223467877 Mon Sep 17 00:00:00 2001 From: Julia <5765049+sxhinzvc@users.noreply.github.com> Date: Fri, 1 Sep 2023 16:44:13 -0400 Subject: [PATCH] Add support for $percentile aggregation operator. Closes #4473 Original Pull Request: #4496 --- .../aggregation/AccumulatorOperators.java | 105 ++++++++++++++++++ .../core/aggregation/ArithmeticOperators.java | 16 +++ .../core/spel/MethodReferenceNode.java | 3 + .../AccumulatorOperatorsUnitTests.java | 24 ++++ .../core/aggregation/AggregationTests.java | 19 ++++ .../aggregation/GroupOperationUnitTests.java | 14 +++ .../ProjectionOperationUnitTests.java | 20 ++++ .../SpelExpressionTransformerUnitTests.java | 15 ++- .../pages/mongodb/aggregation-framework.adoc | 2 +- 9 files changed, 216 insertions(+), 2 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java index 9bb0d9e01..a69555c4d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java @@ -15,8 +15,11 @@ */ package org.springframework.data.mongodb.core.aggregation; +import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.bson.Document; import org.springframework.util.Assert; @@ -25,6 +28,7 @@ import org.springframework.util.Assert; * Gateway to {@literal accumulator} aggregation operations. * * @author Christoph Strobl + * @author Julia Lee * @since 1.10 * @soundtrack Rage Against The Machine - Killing In The Name */ @@ -52,6 +56,7 @@ public class AccumulatorOperators { /** * @author Christoph Strobl + * @author Julia Lee */ public static class AccumulatorOperatorFactory { @@ -246,6 +251,20 @@ public class AccumulatorOperators { }; } + /** + * Creates new {@link AggregationExpression} that calculates the requested percentile(s) of the + * associated numeric value expression. + * + * @return new instance of {@link Percentile}. + * @param percentages must not be {@literal null}. + * @since 4.2 + */ + public Percentile percentile(Double... percentages) { + Percentile percentile = usesFieldRef() ? Percentile.percentileOf(fieldReference) + : Percentile.percentileOf(expression); + return percentile.percentages(percentages); + } + private boolean usesFieldRef() { return fieldReference != null; } @@ -977,4 +996,90 @@ public class AccumulatorOperators { return "$expMovingAvg"; } } + + /** + * {@link AggregationExpression} for {@code $percentile}. + * + * @author Julia Lee + * @since 4.2 + */ + public static class Percentile extends AbstractAggregationExpression { + + private Percentile(Object value) { + super(value); + } + + /** + * Creates new {@link Percentile}. + * + * @param fieldReference must not be {@literal null}. + * @return new instance of {@link Percentile}. + */ + public static Percentile percentileOf(String fieldReference) { + + Assert.notNull(fieldReference, "FieldReference must not be null"); + Map fields = new HashMap<>(); + fields.put("input", Fields.field(fieldReference)); + fields.put("method", "approximate"); + return new Percentile(fields); + } + + /** + * Creates new {@link Percentile}. + * + * @param expression must not be {@literal null}. + * @return new instance of {@link Percentile}. + */ + public static Percentile percentileOf(AggregationExpression expression) { + + Assert.notNull(expression, "Expression must not be null"); + Map fields = new HashMap<>(); + fields.put("input", expression); + fields.put("method", "approximate"); + return new Percentile(fields); + } + + /** + * Define the percentile value(s) that must resolve to percentages in the range {@code 0.0 - 1.0} inclusive. + * + * @param percentages must not be {@literal null}. + * @return new instance of {@link Percentile}. + */ + public Percentile percentages(Double... percentages) { + + Assert.notEmpty(percentages, "Percentages must not be null or empty"); + return new Percentile(append("p", Arrays.asList(percentages))); + } + + /** + * Creates new {@link Percentile} with all previously added inputs appending the given one.
+ * NOTE: Only possible in {@code $project} stage. + * + * @param fieldReference must not be {@literal null}. + * @return new instance of {@link Percentile}. + */ + public Percentile and(String fieldReference) { + + Assert.notNull(fieldReference, "FieldReference must not be null"); + return new Percentile(appendTo("input", Fields.field(fieldReference))); + } + + /** + * Creates new {@link Percentile} with all previously added inputs appending the given one.
+ * NOTE: Only possible in {@code $project} stage. + * + * @param expression must not be {@literal null}. + * @return new instance of {@link Percentile}. + */ + public Percentile and(AggregationExpression expression) { + + Assert.notNull(expression, "Expression must not be null"); + return new Percentile(appendTo("input", expression)); + } + + @Override + protected String getMongoMethod() { + return "$percentile"; + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java index 2bd4df637..d985e3b7b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java @@ -25,6 +25,7 @@ import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Co import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovarianceSamp; import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Max; import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Min; +import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Percentile; import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.StdDevPop; import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.StdDevSamp; import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Sum; @@ -41,6 +42,7 @@ import org.springframework.util.StringUtils; * @author Christoph Strobl * @author Mark Paluch * @author Mushtaq Ahmed + * @author Julia Lee * @since 1.10 */ public class ArithmeticOperators { @@ -932,6 +934,20 @@ public class ArithmeticOperators { return usesFieldRef() ? Tanh.tanhOf(fieldReference, unit) : Tanh.tanhOf(expression, unit); } + /** + * Creates new {@link AggregationExpression} that calculates the requested percentile(s) of the + * numeric value. + * + * @return new instance of {@link Percentile}. + * @param percentages must not be {@literal null}. + * @since 4.2 + */ + public Percentile percentile(Double... percentages) { + Percentile percentile = usesFieldRef() ? AccumulatorOperators.Percentile.percentileOf(fieldReference) + : AccumulatorOperators.Percentile.percentileOf(expression); + return percentile.percentages(percentages); + } + private boolean usesFieldRef() { return fieldReference != null; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java index 611e3dcd3..c5f51800f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java @@ -35,6 +35,7 @@ import org.springframework.util.ObjectUtils; * @author Sebastien Gerard * @author Christoph Strobl * @author Mark Paluch + * @author Julia Lee */ public class MethodReferenceNode extends ExpressionNode { @@ -228,6 +229,8 @@ public class MethodReferenceNode extends ExpressionNode { .mappingParametersTo("n", "input")); map.put("minN", mapArgRef().forOperator("$minN") // .mappingParametersTo("n", "input")); + map.put("percentile", mapArgRef().forOperator("$percentile") // + .mappingParametersTo("input", "p", "method")); // TYPE OPERATORS map.put("type", singleArgRef().forOperator("$type")); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java index 889a99b81..a43b0f862 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java @@ -31,6 +31,7 @@ import org.springframework.data.mongodb.util.aggregation.TestAggregationContext; * Unit tests for {@link AccumulatorOperators}. * * @author Christoph Strobl + * @author Julia Lee */ class AccumulatorOperatorsUnitTests { @@ -108,6 +109,29 @@ class AccumulatorOperatorsUnitTests { .isEqualTo(Document.parse("{ $minN: { n: 3, input : \"$price\" } }")); } + @Test // GH-4473 + void rendersPercentileWithFieldReference() { + + assertThat(valueOf("score").percentile(0.2).toDocument(Aggregation.DEFAULT_CONTEXT)) + .isEqualTo(Document.parse("{ $percentile: { input: \"$score\", method: \"approximate\", p: [0.2] } }")); + + assertThat(valueOf("score").percentile(0.3, 0.9).toDocument(Aggregation.DEFAULT_CONTEXT)) + .isEqualTo(Document.parse("{ $percentile: { input: \"$score\", method: \"approximate\", p: [0.3, 0.9] } }")); + + assertThat(valueOf("score").percentile(0.3, 0.9).and("scoreTwo").toDocument(Aggregation.DEFAULT_CONTEXT)) + .isEqualTo(Document.parse("{ $percentile: { input: [\"$score\", \"$scoreTwo\"], method: \"approximate\", p: [0.3, 0.9] } }")); + } + + @Test // GH-4473 + void rendersPercentileWithExpression() { + + assertThat(valueOf(Sum.sumOf("score")).percentile(0.1).toDocument(Aggregation.DEFAULT_CONTEXT)) + .isEqualTo(Document.parse("{ $percentile: { input: {\"$sum\": \"$score\"}, method: \"approximate\", p: [0.1] } }")); + + assertThat(valueOf("scoreOne").percentile(0.1, 0.2).and(Sum.sumOf("scoreTwo")).toDocument(Aggregation.DEFAULT_CONTEXT)) + .isEqualTo(Document.parse("{ $percentile: { input: [\"$scoreOne\", {\"$sum\": \"$scoreTwo\"}], method: \"approximate\", p: [0.1, 0.2] } }")); + } + static class Jedi { String name; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index 17b6301c9..5025d7fdc 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -1893,6 +1893,25 @@ public class AggregationTests { assertThat(categorizeByYear).hasSize(3); } + @Test // GH-4473 + @EnableIfMongoServerVersion(isGreaterThanEqual = "7.0") + void percentileShouldBeAppliedCorrectly() { + + mongoTemplate.insert(new DATAMONGO788(15, 16)); + mongoTemplate.insert(new DATAMONGO788(17, 18)); + + Aggregation agg = Aggregation.newAggregation( + project().and(ArithmeticOperators.valueOf("x").percentile(0.9).and("y")) + .as("ninetiethPercentile")); + + AggregationResults result = mongoTemplate.aggregate(agg, DATAMONGO788.class, Document.class); + + // MongoDB server returns $percentile as an array of doubles + List rawResults = (List) result.getRawResults().get("results"); + assertThat((List) rawResults.get(0).get("ninetiethPercentile")).containsExactly(16.0); + assertThat((List) rawResults.get(1).get("ninetiethPercentile")).containsExactly(18.0); + } + @Test // DATAMONGO-1986 void runMatchOperationCriteriaThroughQueryMapperForTypedAggregation() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GroupOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GroupOperationUnitTests.java index 687fdc5d5..d5c3e547e 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GroupOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GroupOperationUnitTests.java @@ -25,6 +25,7 @@ import org.junit.jupiter.api.Test; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort.Direction; import org.springframework.data.mongodb.core.DocumentTestUtils; +import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Percentile; import org.springframework.data.mongodb.core.aggregation.SelectionOperators.Bottom; import org.springframework.data.mongodb.core.query.Criteria; @@ -34,6 +35,7 @@ import org.springframework.data.mongodb.core.query.Criteria; * @author Oliver Gierke * @author Thomas Darimont * @author Gustavo de Geus + * @author Julia Lee */ class GroupOperationUnitTests { @@ -266,6 +268,18 @@ class GroupOperationUnitTests { Document.parse("{ $bottom : { output: [ \"$playerId\", \"$score\" ], sortBy: { \"score\": -1 }}}")); } + @Test // GH-4473 + void groupOperationAllowsAddingFieldWithPercentileAggregationExpression() { + + GroupOperation groupOperation = Aggregation.group("id").and("scorePercentile", + Percentile.percentileOf("score").percentages(0.2)); + + Document groupClause = extractDocumentFromGroupOperation(groupOperation); + + assertThat(groupClause).containsEntry("scorePercentile", + Document.parse("{ $percentile : { input: \"$score\", method: \"approximate\", p: [0.2]}}")); + } + private Document extractDocumentFromGroupOperation(GroupOperation groupOperation) { Document document = groupOperation.toDocument(Aggregation.DEFAULT_CONTEXT); Document groupClause = DocumentTestUtils.getAsDocument(document, "$group"); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java index 227aa34f7..87934ade1 100755 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperationUnitTests.java @@ -2241,6 +2241,26 @@ public class ProjectionOperationUnitTests { "{ $project: { \"author\" : 1, \"myArray\" : [ \"$ti_t_le\", \"plain - string\", { \"$sum\" : [\"$ti_t_le\", 10] } ] } } ] }")); } + @Test // GH-4473 + void shouldRenderPercentileAggregationExpression() { + + Document agg = project() + .and(ArithmeticOperators.valueOf("score").percentile(0.3, 0.9)).as("scorePercentiles") + .toDocument(Aggregation.DEFAULT_CONTEXT); + + assertThat(agg).isEqualTo(Document.parse("{ $project: { scorePercentiles: { $percentile: { input: \"$score\", method: \"approximate\", p: [0.3, 0.9] } }} } }")); + } + + @Test // GH-4473 + void shouldRenderPercentileWithMultipleArgsAggregationExpression() { + + Document agg = project() + .and(ArithmeticOperators.valueOf("scoreOne").percentile(0.4).and("scoreTwo")).as("scorePercentiles") + .toDocument(Aggregation.DEFAULT_CONTEXT); + + assertThat(agg).isEqualTo(Document.parse("{ $project: { scorePercentiles: { $percentile: { input: [\"$scoreOne\", \"$scoreTwo\"], method: \"approximate\", p: [0.4] } }} } }")); + } + private static Document extractOperation(String field, Document fromProjectClause) { return (Document) fromProjectClause.get(field); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java index 3ea6d4a11..d7aa5cfc7 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java @@ -33,6 +33,7 @@ import org.springframework.data.mongodb.core.Person; * @author Oliver Gierke * @author Christoph Strobl * @author Divya Srivastava + * @author Julia Lee */ public class SpelExpressionTransformerUnitTests { @@ -1255,7 +1256,19 @@ public class SpelExpressionTransformerUnitTests { void shouldRenderLocf() { assertThat(transform("locf(price)")).isEqualTo("{ $locf: \"$price\" }"); } - + + @Test // GH-4473 + void shouldRenderPercentile() { + assertThat(transform("percentile(new String[]{\"$scoreOne\", \"$scoreTwo\" }, new double[]{0.4}, \"approximate\")")) + .isEqualTo("{ $percentile : { input : [\"$scoreOne\", \"$scoreTwo\"], p : [0.4], method : \"approximate\" }}"); + + assertThat(transform("percentile(score, new double[]{0.4, 0.85}, \"approximate\")")) + .isEqualTo("{ $percentile : { input : \"$score\", p : [0.4, 0.85], method : \"approximate\" }}"); + + assertThat(transform("percentile(\"$score\", new double[]{0.4, 0.85}, \"approximate\")")) + .isEqualTo("{ $percentile : { input : \"$score\", p : [0.4, 0.85], method : \"approximate\" }}"); + } + private Document transform(String expression, Object... params) { return (Document) transformer.transform(expression, Aggregation.DEFAULT_CONTEXT, params); } diff --git a/src/main/antora/modules/ROOT/pages/mongodb/aggregation-framework.adoc b/src/main/antora/modules/ROOT/pages/mongodb/aggregation-framework.adoc index 3f34e55ae..18cb70d4a 100644 --- a/src/main/antora/modules/ROOT/pages/mongodb/aggregation-framework.adoc +++ b/src/main/antora/modules/ROOT/pages/mongodb/aggregation-framework.adoc @@ -112,7 +112,7 @@ At the time of this writing, we provide support for the following Aggregation Op | `setEquals`, `setIntersection`, `setUnion`, `setDifference`, `setIsSubset`, `anyElementTrue`, `allElementsTrue` | Group/Accumulator Aggregation Operators -| `addToSet`, `bottom`, `bottomN`, `covariancePop`, `covarianceSamp`, `expMovingAvg`, `first`, `firstN`, `last`, `lastN` `max`, `maxN`, `min`, `minN`, `avg`, `push`, `sum`, `top`, `topN`, `count` (+++*+++), `stdDevPop`, `stdDevSamp` +| `addToSet`, `bottom`, `bottomN`, `covariancePop`, `covarianceSamp`, `expMovingAvg`, `first`, `firstN`, `last`, `lastN` `max`, `maxN`, `min`, `minN`, `avg`, `push`, `sum`, `top`, `topN`, `count` (+++*+++), `percentile`, `stdDevPop`, `stdDevSamp` | Arithmetic Aggregation Operators | `abs`, `acos`, `acosh`, `add` (+++*+++ via `plus`), `asin`, `asin`, `atan`, `atan2`, `atanh`, `ceil`, `cos`, `cosh`, `derivative`, `divide`, `exp`, `floor`, `integral`, `ln`, `log`, `log10`, `mod`, `multiply`, `pow`, `round`, `sqrt`, `subtract` (+++*+++ via `minus`), `sin`, `sinh`, `tan`, `tanh`, `trunc`