diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFields.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFields.java index 764398d22..9021626ac 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFields.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFields.java @@ -268,14 +268,21 @@ public final class ExposedFields implements Iterable { return field.isAliased(); } + /** + * @return the synthetic + */ + public boolean isSynthetic() { + return synthetic; + } + /** * Returns whether the field can be referred to using the given name. * - * @param input + * @param name * @return */ - public boolean canBeReferredToBy(String input) { - return getTarget().equals(input); + public boolean canBeReferredToBy(String name) { + return getName().equals(name) || getTarget().equals(name); } /* @@ -340,6 +347,7 @@ public final class ExposedFields implements Iterable { public FieldReference(ExposedField field) { Assert.notNull(field, "ExposedField must not be null!"); + this.field = field; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java index 6ac32fa4a..0e3a05d3c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java @@ -65,7 +65,7 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo */ @Override public FieldReference getReference(Field field) { - return getReference(field.getTarget()); + return getReference(field, field.getTarget()); } /* @@ -74,13 +74,30 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo */ @Override public FieldReference getReference(String name) { + return getReference(null, name); + } + + /** + * Returns a {@link FieldReference} to the given {@link Field} with the given {@code name}. + * + * @param field may be {@literal null} + * @param name must not be {@literal null} + * @return + */ + private FieldReference getReference(Field field, String name) { Assert.notNull(name, "Name must not be null!"); - ExposedField field = exposedFields.getField(name); + ExposedField exposedField = exposedFields.getField(name); + + if (exposedField != null) { + + if (field != null) { + // we return a FieldReference to the given field directly to make sure that we reference the proper alias here. + return new FieldReference(new ExposedField(field, exposedField.isSynthetic())); + } - if (field != null) { - return new FieldReference(field); + return new FieldReference(exposedField); } if (name.contains(".")) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index 6361985b7..25cbd3b89 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -47,6 +47,7 @@ import org.springframework.data.annotation.Id; import org.springframework.data.mapping.model.MappingException; import org.springframework.data.mongodb.core.CollectionCallback; import org.springframework.data.mongodb.core.MongoTemplate; +import org.springframework.data.mongodb.core.aggregation.AggregationTests.CarDescriptor.Entry; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.util.Version; import org.springframework.test.context.ContextConfiguration; @@ -822,6 +823,40 @@ public class AggregationTests { assertThat(invoice.getTotalAmount(), is(closeTo(9.877, 000001))); } + /** + * @see DATAMONGO-924 + */ + @Test + public void shouldAllowGroupingByAliasedFieldDefinedInFormerAggregationStage() { + + mongoTemplate.dropCollection(CarPerson.class); + + CarPerson person1 = new CarPerson("first1", "last1", new CarDescriptor.Entry("MAKE1", "MODEL1", 2000), + new CarDescriptor.Entry("MAKE1", "MODEL2", 2001), new CarDescriptor.Entry("MAKE2", "MODEL3", 2010), + new CarDescriptor.Entry("MAKE3", "MODEL4", 2014)); + + CarPerson person2 = new CarPerson("first2", "last2", new CarDescriptor.Entry("MAKE3", "MODEL4", 2014)); + + CarPerson person3 = new CarPerson("first3", "last3", new CarDescriptor.Entry("MAKE2", "MODEL5", 2011)); + + mongoTemplate.save(person1); + mongoTemplate.save(person2); + mongoTemplate.save(person3); + + TypedAggregation agg = Aggregation.newAggregation(CarPerson.class, + unwind("descriptors.carDescriptor.entries"), // + project() // + .and("descriptors.carDescriptor.entries.make").as("make") // + .and("descriptors.carDescriptor.entries.model").as("model") // + .and("firstName").as("firstName") // + .and("lastName").as("lastName"), // + group("make")); + + AggregationResults result = mongoTemplate.aggregate(agg, DBObject.class); + + assertThat(result.getMappedResults(), hasSize(3)); + } + private void assertLikeStats(LikeStats like, String id, long count) { assertThat(like, is(notNullValue())); @@ -938,4 +973,52 @@ public class AggregationTests { this.createDate = createDate; } } + + @org.springframework.data.mongodb.core.mapping.Document + static class CarPerson { + + @Id private String id; + private String firstName; + private String lastName; + private Descriptors descriptors; + + public CarPerson(String firstname, String lastname, Entry... entries) { + this.firstName = firstname; + this.lastName = lastname; + + this.descriptors = new Descriptors(); + + this.descriptors.carDescriptor = new CarDescriptor(entries); + } + } + + static class Descriptors { + private CarDescriptor carDescriptor; + } + + static class CarDescriptor { + + private List entries = new ArrayList(); + + public CarDescriptor(Entry... entries) { + + for (Entry entry : entries) { + this.entries.add(entry); + } + } + + static class Entry { + private String make; + private String model; + private int year; + + public Entry() {} + + public Entry(String make, String model, int year) { + this.make = make; + this.model = model; + this.year = year; + } + } + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java index 2012200e6..95635bc66 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java @@ -200,4 +200,28 @@ public class AggregationUnitTests { assertThat(id.get("ruleType"), is((Object) "$rules.ruleType")); } + + /** + * @see DATAMONGO-924 + */ + @Test + public void referencingProjectionAliasesFromPreviousStepShouldReferToTheSameFieldTarget() { + + DBObject agg = newAggregation( // + project().and("foo.bar").as("ba") // + , project().and("ba").as("b") // + ).toDbObject("foo", Aggregation.DEFAULT_CONTEXT); + + DBObject projection0 = extractPipelineElement(agg, 0, "$project"); + assertThat(projection0, is((DBObject) new BasicDBObject("ba", "$foo.bar"))); + + DBObject projection1 = extractPipelineElement(agg, 1, "$project"); + assertThat(projection1, is((DBObject) new BasicDBObject("b", "$ba"))); + } + + private DBObject extractPipelineElement(DBObject agg, int index, String operation) { + + List pipeline = (List) agg.get("pipeline"); + return (DBObject) pipeline.get(index).get(operation); + } }