Browse Source

Add support for `$covariancePop` and `$covarianceSamp` aggregation expressions.

Closes: #3712
Original pull request: #3740.
pull/3785/head
Christoph Strobl 5 years ago committed by Mark Paluch
parent
commit
c574e5cf8a
No known key found for this signature in database
GPG Key ID: 4406B84C1661DCD1
  1. 177
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java
  2. 59
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java
  3. 2
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java
  4. 75
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/aggregation/TestAggregationContext.java
  5. 77
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java
  6. 10
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java

177
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java

@ -142,6 +142,63 @@ public class AccumulatorOperators { @@ -142,6 +142,63 @@ public class AccumulatorOperators {
return usesFieldRef() ? StdDevSamp.stdDevSampOf(fieldReference) : StdDevSamp.stdDevSampOf(expression);
}
/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
* field to calculate the population covariance of the two.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovariancePop covariancePop(String fieldReference) {
return covariancePop().and(fieldReference);
}
/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
* {@link AggregationExpression expression} to calculate the population covariance of the two.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovariancePop covariancePop(AggregationExpression expression) {
return covariancePop().and(expression);
}
private CovariancePop covariancePop() {
return usesFieldRef() ? CovariancePop.covariancePopOf(fieldReference) : CovariancePop.covariancePopOf(expression);
}
/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
* field to calculate the sample covariance of the two.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovarianceSamp covarianceSamp(String fieldReference) {
return covarianceSamp().and(fieldReference);
}
/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
* {@link AggregationExpression expression} to calculate the sample covariance of the two.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovarianceSamp covarianceSamp(AggregationExpression expression) {
return covarianceSamp().and(expression);
}
private CovarianceSamp covarianceSamp() {
return usesFieldRef() ? CovarianceSamp.covarianceSampOf(fieldReference)
: CovarianceSamp.covarianceSampOf(expression);
}
private boolean usesFieldRef() {
return fieldReference != null;
}
@ -658,4 +715,124 @@ public class AccumulatorOperators { @@ -658,4 +715,124 @@ public class AccumulatorOperators {
return super.toDocument(value, context);
}
}
/**
* {@link AggregationExpression} for {@code $covariancePop}.
*
* @author Christoph Strobl
* @since 3.3
*/
public static class CovariancePop extends AbstractAggregationExpression {
private CovariancePop(Object value) {
super(value);
}
/**
* Creates new {@link CovariancePop}.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
*/
public static CovariancePop covariancePopOf(String fieldReference) {
Assert.notNull(fieldReference, "FieldReference must not be null!");
return new CovariancePop(asFields(fieldReference));
}
/**
* Creates new {@link CovariancePop}.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
*/
public static CovariancePop covariancePopOf(AggregationExpression expression) {
return new CovariancePop(Collections.singletonList(expression));
}
/**
* Creates new {@link CovariancePop} with all previously added arguments appending the given one.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
*/
public CovariancePop and(String fieldReference) {
return new CovariancePop(append(asFields(fieldReference)));
}
/**
* Creates new {@link CovariancePop} with all previously added arguments appending the given one.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
*/
public CovariancePop and(AggregationExpression expression) {
return new CovariancePop(append(expression));
}
@Override
protected String getMongoMethod() {
return "$covariancePop";
}
}
/**
* {@link AggregationExpression} for {@code $covarianceSamp}.
*
* @author Christoph Strobl
* @since 3.3
*/
public static class CovarianceSamp extends AbstractAggregationExpression {
private CovarianceSamp(Object value) {
super(value);
}
/**
* Creates new {@link CovarianceSamp}.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovarianceSamp}.
*/
public static CovarianceSamp covarianceSampOf(String fieldReference) {
Assert.notNull(fieldReference, "FieldReference must not be null!");
return new CovarianceSamp(asFields(fieldReference));
}
/**
* Creates new {@link CovarianceSamp}.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovarianceSamp}.
*/
public static CovarianceSamp covarianceSampOf(AggregationExpression expression) {
return new CovarianceSamp(Collections.singletonList(expression));
}
/**
* Creates new {@link CovarianceSamp} with all previously added arguments appending the given one.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovarianceSamp}.
*/
public CovarianceSamp and(String fieldReference) {
return new CovarianceSamp(append(asFields(fieldReference)));
}
/**
* Creates new {@link CovarianceSamp} with all previously added arguments appending the given one.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovarianceSamp}.
*/
public CovarianceSamp and(AggregationExpression expression) {
return new CovarianceSamp(append(expression));
}
@Override
protected String getMongoMethod() {
return "$covarianceSamp";
}
}
}

59
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java

@ -19,6 +19,8 @@ import java.util.Collections; @@ -19,6 +19,8 @@ import java.util.Collections;
import java.util.List;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Avg;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovariancePop;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovarianceSamp;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Max;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Min;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.StdDevPop;
@ -511,6 +513,63 @@ public class ArithmeticOperators { @@ -511,6 +513,63 @@ public class ArithmeticOperators {
: AccumulatorOperators.StdDevSamp.stdDevSampOf(expression);
}
/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
* field to calculate the population covariance of the two.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovariancePop covariancePop(String fieldReference) {
return covariancePop().and(fieldReference);
}
/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
* {@link AggregationExpression expression} to calculate the population covariance of the two.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovariancePop covariancePop(AggregationExpression expression) {
return covariancePop().and(expression);
}
private CovariancePop covariancePop() {
return usesFieldRef() ? CovariancePop.covariancePopOf(fieldReference) : CovariancePop.covariancePopOf(expression);
}
/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
* field to calculate the sample covariance of the two.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovarianceSamp covarianceSamp(String fieldReference) {
return covarianceSamp().and(fieldReference);
}
/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
* {@link AggregationExpression expression} to calculate the sample covariance of the two.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovarianceSamp covarianceSamp(AggregationExpression expression) {
return covarianceSamp().and(expression);
}
private CovarianceSamp covarianceSamp() {
return usesFieldRef() ? CovarianceSamp.covarianceSampOf(fieldReference)
: CovarianceSamp.covarianceSampOf(expression);
}
/**
* Creates new {@link AggregationExpression} that rounds a number to a whole integer or to a specified decimal
* place.

2
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java

@ -170,6 +170,8 @@ public class MethodReferenceNode extends ExpressionNode { @@ -170,6 +170,8 @@ public class MethodReferenceNode extends ExpressionNode {
map.put("addToSet", singleArgRef().forOperator("$addToSet"));
map.put("stdDevPop", arrayArgRef().forOperator("$stdDevPop"));
map.put("stdDevSamp", arrayArgRef().forOperator("$stdDevSamp"));
map.put("covariancePop", arrayArgRef().forOperator("$covariancePop"));
map.put("covarianceSamp", arrayArgRef().forOperator("$covarianceSamp"));
// TYPE OPERATORS
map.put("type", singleArgRef().forOperator("$type"));

75
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/aggregation/TestAggregationContext.java

@ -0,0 +1,75 @@ @@ -0,0 +1,75 @@
/*
* Copyright 2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.util.aggregation;
import org.bson.Document;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference;
import org.springframework.data.mongodb.core.aggregation.Field;
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.lang.Nullable;
/**
* @author Christoph Strobl
*/
public class TestAggregationContext implements AggregationOperationContext {
private final AggregationOperationContext delegate;
private TestAggregationContext(AggregationOperationContext delegate) {
this.delegate = delegate;
}
public static AggregationOperationContext contextFor(@Nullable Class<?> type) {
MappingMongoConverter mongoConverter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE,
new MongoMappingContext());
mongoConverter.afterPropertiesSet();
return contextFor(type, mongoConverter);
}
public static AggregationOperationContext contextFor(@Nullable Class<?> type, MongoConverter mongoConverter) {
if (type == null) {
return Aggregation.DEFAULT_CONTEXT;
}
return new TestAggregationContext(new TypeBasedAggregationOperationContext(type, mongoConverter.getMappingContext(),
new QueryMapper(mongoConverter)).continueOnMissingFieldReference());
}
@Override
public Document getMappedObject(Document document, @Nullable Class<?> type) {
return delegate.getMappedObject(document, type);
}
@Override
public FieldReference getReference(Field field) {
return delegate.getReference(field);
}
@Override
public FieldReference getReference(String name) {
return delegate.getReference(name);
}
}

77
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java

@ -0,0 +1,77 @@ @@ -0,0 +1,77 @@
/*
* Copyright 2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.aggregation;
import static org.assertj.core.api.Assertions.*;
import java.util.Arrays;
import java.util.Date;
import org.bson.Document;
import org.junit.jupiter.api.Test;
import org.springframework.data.mongodb.core.aggregation.DateOperators.Year;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.util.aggregation.TestAggregationContext;
/**
* @author Christoph Strobl
*/
class AccumulatorOperatorsUnitTests {
@Test // GH-3712
void rendersCovariancePopWithFieldReference() {
assertThat(AccumulatorOperators.valueOf("balance").covariancePop("midichlorianCount")
.toDocument(TestAggregationContext.contextFor(Jedi.class)))
.isEqualTo(new Document("$covariancePop", Arrays.asList("$balance", "$force")));
}
@Test // GH-3712
void rendersCovariancePopWithExpression() {
assertThat(AccumulatorOperators.valueOf(Year.yearOf("birthdate")).covariancePop("midichlorianCount")
.toDocument(TestAggregationContext.contextFor(Jedi.class)))
.isEqualTo(new Document("$covariancePop", Arrays.asList(new Document("$year", "$birthdate"), "$force")));
}
@Test // GH-3712
void rendersCovarianceSampWithFieldReference() {
assertThat(AccumulatorOperators.valueOf("balance").covarianceSamp("midichlorianCount")
.toDocument(TestAggregationContext.contextFor(Jedi.class)))
.isEqualTo(new Document("$covarianceSamp", Arrays.asList("$balance", "$force")));
}
@Test // GH-3712
void rendersCovarianceSampWithExpression() {
assertThat(AccumulatorOperators.valueOf(Year.yearOf("birthdate")).covarianceSamp("midichlorianCount")
.toDocument(TestAggregationContext.contextFor(Jedi.class)))
.isEqualTo(new Document("$covarianceSamp", Arrays.asList(new Document("$year", "$birthdate"), "$force")));
}
static class Jedi {
String name;
Date birthdate;
@Field("force")
Integer midichlorianCount;
Integer balance;
}
}

10
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java

@ -946,6 +946,16 @@ public class SpelExpressionTransformerUnitTests { @@ -946,6 +946,16 @@ public class SpelExpressionTransformerUnitTests {
assertThat(transform("round(field, 2)")).isEqualTo(Document.parse("{ \"$round\" : [\"$field\", 2]}"));
}
@Test // GH-3712
void shouldRenderCovariancePop() {
assertThat(transform("covariancePop(field1, field2)")).isEqualTo(Document.parse("{ \"$covariancePop\" : [\"$field1\", \"$field2\"]}"));
}
@Test // GH-3712
void shouldRenderCovarianceSamp() {
assertThat(transform("covarianceSamp(field1, field2)")).isEqualTo(Document.parse("{ \"$covarianceSamp\" : [\"$field1\", \"$field2\"]}"));
}
private Object transform(String expression, Object... params) {
Object result = transformer.transform(expression, Aggregation.DEFAULT_CONTEXT, params);
return result == null ? null : (!(result instanceof org.bson.Document) ? result.toString() : result);

Loading…
Cancel
Save