Browse Source

DATAMONGO-924 - Improve aggregation field reference resolving.

Previously we didn't support referring to aliased fields defined in former stages of an aggregation pipeline. We now also consider field aliases during field reference lookup.

Original pull request: #176.
1.4.x
Thomas Darimont 12 years ago committed by Oliver Gierke
parent
commit
f0fc3961d2
  1. 14
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFields.java
  2. 25
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java
  3. 83
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
  4. 24
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java

14
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFields.java

@ -268,14 +268,21 @@ public final class ExposedFields implements Iterable<ExposedField> { @@ -268,14 +268,21 @@ public final class ExposedFields implements Iterable<ExposedField> {
return field.isAliased();
}
/**
* @return the synthetic
*/
public boolean isSynthetic() {
return synthetic;
}
/**
* Returns whether the field can be referred to using the given name.
*
* @param input
* @param name
* @return
*/
public boolean canBeReferredToBy(String input) {
return getTarget().equals(input);
public boolean canBeReferredToBy(String name) {
return getName().equals(name) || getTarget().equals(name);
}
/*
@ -340,6 +347,7 @@ public final class ExposedFields implements Iterable<ExposedField> { @@ -340,6 +347,7 @@ public final class ExposedFields implements Iterable<ExposedField> {
public FieldReference(ExposedField field) {
Assert.notNull(field, "ExposedField must not be null!");
this.field = field;
}

25
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java

@ -65,7 +65,7 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo @@ -65,7 +65,7 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo
*/
@Override
public FieldReference getReference(Field field) {
return getReference(field.getTarget());
return getReference(field, field.getTarget());
}
/*
@ -74,13 +74,30 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo @@ -74,13 +74,30 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo
*/
@Override
public FieldReference getReference(String name) {
return getReference(null, name);
}
/**
* Returns a {@link FieldReference} to the given {@link Field} with the given {@code name}.
*
* @param field may be {@literal null}
* @param name must not be {@literal null}
* @return
*/
private FieldReference getReference(Field field, String name) {
Assert.notNull(name, "Name must not be null!");
ExposedField field = exposedFields.getField(name);
ExposedField exposedField = exposedFields.getField(name);
if (exposedField != null) {
if (field != null) {
// we return a FieldReference to the given field directly to make sure that we reference the proper alias here.
return new FieldReference(new ExposedField(field, exposedField.isSynthetic()));
}
if (field != null) {
return new FieldReference(field);
return new FieldReference(exposedField);
}
if (name.contains(".")) {

83
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java

@ -47,6 +47,7 @@ import org.springframework.data.annotation.Id; @@ -47,6 +47,7 @@ import org.springframework.data.annotation.Id;
import org.springframework.data.mapping.model.MappingException;
import org.springframework.data.mongodb.core.CollectionCallback;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.AggregationTests.CarDescriptor.Entry;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.util.Version;
import org.springframework.test.context.ContextConfiguration;
@ -822,6 +823,40 @@ public class AggregationTests { @@ -822,6 +823,40 @@ public class AggregationTests {
assertThat(invoice.getTotalAmount(), is(closeTo(9.877, 000001)));
}
/**
* @see DATAMONGO-924
*/
@Test
public void shouldAllowGroupingByAliasedFieldDefinedInFormerAggregationStage() {
mongoTemplate.dropCollection(CarPerson.class);
CarPerson person1 = new CarPerson("first1", "last1", new CarDescriptor.Entry("MAKE1", "MODEL1", 2000),
new CarDescriptor.Entry("MAKE1", "MODEL2", 2001), new CarDescriptor.Entry("MAKE2", "MODEL3", 2010),
new CarDescriptor.Entry("MAKE3", "MODEL4", 2014));
CarPerson person2 = new CarPerson("first2", "last2", new CarDescriptor.Entry("MAKE3", "MODEL4", 2014));
CarPerson person3 = new CarPerson("first3", "last3", new CarDescriptor.Entry("MAKE2", "MODEL5", 2011));
mongoTemplate.save(person1);
mongoTemplate.save(person2);
mongoTemplate.save(person3);
TypedAggregation<CarPerson> agg = Aggregation.newAggregation(CarPerson.class,
unwind("descriptors.carDescriptor.entries"), //
project() //
.and("descriptors.carDescriptor.entries.make").as("make") //
.and("descriptors.carDescriptor.entries.model").as("model") //
.and("firstName").as("firstName") //
.and("lastName").as("lastName"), //
group("make"));
AggregationResults<DBObject> result = mongoTemplate.aggregate(agg, DBObject.class);
assertThat(result.getMappedResults(), hasSize(3));
}
private void assertLikeStats(LikeStats like, String id, long count) {
assertThat(like, is(notNullValue()));
@ -938,4 +973,52 @@ public class AggregationTests { @@ -938,4 +973,52 @@ public class AggregationTests {
this.createDate = createDate;
}
}
@org.springframework.data.mongodb.core.mapping.Document
static class CarPerson {
@Id private String id;
private String firstName;
private String lastName;
private Descriptors descriptors;
public CarPerson(String firstname, String lastname, Entry... entries) {
this.firstName = firstname;
this.lastName = lastname;
this.descriptors = new Descriptors();
this.descriptors.carDescriptor = new CarDescriptor(entries);
}
}
static class Descriptors {
private CarDescriptor carDescriptor;
}
static class CarDescriptor {
private List<Entry> entries = new ArrayList<AggregationTests.CarDescriptor.Entry>();
public CarDescriptor(Entry... entries) {
for (Entry entry : entries) {
this.entries.add(entry);
}
}
static class Entry {
private String make;
private String model;
private int year;
public Entry() {}
public Entry(String make, String model, int year) {
this.make = make;
this.model = model;
this.year = year;
}
}
}
}

24
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java

@ -200,4 +200,28 @@ public class AggregationUnitTests { @@ -200,4 +200,28 @@ public class AggregationUnitTests {
assertThat(id.get("ruleType"), is((Object) "$rules.ruleType"));
}
/**
* @see DATAMONGO-924
*/
@Test
public void referencingProjectionAliasesFromPreviousStepShouldReferToTheSameFieldTarget() {
DBObject agg = newAggregation( //
project().and("foo.bar").as("ba") //
, project().and("ba").as("b") //
).toDbObject("foo", Aggregation.DEFAULT_CONTEXT);
DBObject projection0 = extractPipelineElement(agg, 0, "$project");
assertThat(projection0, is((DBObject) new BasicDBObject("ba", "$foo.bar")));
DBObject projection1 = extractPipelineElement(agg, 1, "$project");
assertThat(projection1, is((DBObject) new BasicDBObject("b", "$ba")));
}
private DBObject extractPipelineElement(DBObject agg, int index, String operation) {
List<DBObject> pipeline = (List<DBObject>) agg.get("pipeline");
return (DBObject) pipeline.get(index).get(operation);
}
}

Loading…
Cancel
Save