@ -15,6 +15,11 @@
@@ -15,6 +15,11 @@
* /
package org.springframework.data.mongodb.core.aggregation ;
import java.util.Collection ;
import java.util.HashSet ;
import java.util.Set ;
import java.util.stream.Collectors ;
import org.bson.Document ;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference ;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExpressionFieldReference ;
@ -31,16 +36,18 @@ import org.springframework.util.Assert;
@@ -31,16 +36,18 @@ import org.springframework.util.Assert;
class NestedDelegatingExpressionAggregationOperationContext implements AggregationOperationContext {
private final AggregationOperationContext delegate ;
private final Set < String > inners ;
/ * *
* Creates new { @link NestedDelegatingExpressionAggregationOperationContext } .
*
* @param referenceContext must not be { @literal null } .
* /
public NestedDelegatingExpressionAggregationOperationContext ( AggregationOperationContext referenceContext ) {
NestedDelegatingExpressionAggregationOperationContext ( AggregationOperationContext referenceContext , Collection < Field > inners ) {
Assert . notNull ( referenceContext , "Reference context must not be null!" ) ;
this . delegate = referenceContext ;
this . inners = inners . stream ( ) . map ( Field : : getName ) . collect ( Collectors . toSet ( ) ) ;
}
/ *
@ -67,7 +74,22 @@ class NestedDelegatingExpressionAggregationOperationContext implements Aggregati
@@ -67,7 +74,22 @@ class NestedDelegatingExpressionAggregationOperationContext implements Aggregati
* /
@Override
public FieldReference getReference ( Field field ) {
return new ExpressionFieldReference ( delegate . getReference ( field ) ) ;
FieldReference reference = delegate . getReference ( field ) ;
return ! isInnerVariableReference ( field ) ? reference : new ExpressionFieldReference ( delegate . getReference ( field ) ) ;
}
private boolean isInnerVariableReference ( Field field ) {
if ( inners . isEmpty ( ) ) {
return false ;
}
if ( inners . contains ( field . getName ( ) ) ) {
return true ;
}
return inners . stream ( ) . anyMatch ( it - > field . getTarget ( ) . contains ( "." ) & & field . getTarget ( ) . startsWith ( it ) ) ;
}
/ *