diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index f3d1d03ac..dd81685fc 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -1,5 +1,5 @@ /* - * Copyright 2013 the original author or authors. + * Copyright 2013-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -248,7 +248,7 @@ public class Aggregation { if (operation instanceof FieldsExposingAggregationOperation) { FieldsExposingAggregationOperation exposedFieldsOperation = (FieldsExposingAggregationOperation) operation; - context = new ExposedFieldsAggregationOperationContext(exposedFieldsOperation.getFields()); + context = new ExposedFieldsAggregationOperationContext(exposedFieldsOperation.getFields(), rootContext); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java index ca7d8db5b..6ac32fa4a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java @@ -32,16 +32,22 @@ import com.mongodb.DBObject; class ExposedFieldsAggregationOperationContext implements AggregationOperationContext { private final ExposedFields exposedFields; + private final AggregationOperationContext rootContext; /** - * Creates a new {@link ExposedFieldsAggregationOperationContext} from the given {@link ExposedFields}. + * Creates a new {@link ExposedFieldsAggregationOperationContext} from the given {@link ExposedFields}. Uses the given + * {@link AggregationOperationContext} to perform a mapping to mongo types if necessary. * * @param exposedFields must not be {@literal null}. + * @param rootContext must not be {@literal null}. */ - public ExposedFieldsAggregationOperationContext(ExposedFields exposedFields) { + public ExposedFieldsAggregationOperationContext(ExposedFields exposedFields, AggregationOperationContext rootContext) { Assert.notNull(exposedFields, "ExposedFields must not be null!"); + Assert.notNull(rootContext, "RootContext must not be null!"); + this.exposedFields = exposedFields; + this.rootContext = rootContext; } /* @@ -50,7 +56,7 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo */ @Override public DBObject getMappedObject(DBObject dbObject) { - return dbObject; + return rootContext.getMappedObject(dbObject); } /* diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java index 6aae2af59..3ec62b847 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2013 the original author or authors. + * Copyright 2013-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,25 +17,40 @@ package org.springframework.data.mongodb.core.aggregation; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.*; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; +import java.util.Arrays; +import java.util.List; + +import org.bson.types.ObjectId; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.core.convert.converter.Converter; +import org.springframework.core.convert.support.GenericConversionService; import org.springframework.data.annotation.Id; +import org.springframework.data.annotation.PersistenceConstructor; import org.springframework.data.mapping.model.MappingException; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; +import org.springframework.data.mongodb.core.convert.CustomConversions; import org.springframework.data.mongodb.core.convert.DbRefResolver; import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.mapping.Document; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.data.mongodb.core.query.Criteria; + +import com.mongodb.BasicDBObject; +import com.mongodb.DBObject; /** * Unit tests for {@link TypeBasedAggregationOperationContext}. * * @author Oliver Gierke + * @author Thomas Darimont */ @RunWith(MockitoJUnitRunner.class) public class TypeBasedAggregationOperationContextUnitTests { @@ -89,6 +104,104 @@ public class TypeBasedAggregationOperationContextUnitTests { assertThat(context.getReference("id"), is(new FieldReference(new ExposedField(Fields.field("id", "_id"), true)))); } + /** + * @see DATAMONGO-912 + */ + @Test + public void shouldUseCustomConversionIfPresentAndConversionIsRequiredInFirstStage() { + + CustomConversions customConversions = customAgeConversions(); + converter.setCustomConversions(customConversions); + customConversions.registerConvertersIn((GenericConversionService) converter.getConversionService()); + + AggregationOperationContext context = getContext(FooPerson.class); + + MatchOperation matchStage = match(Criteria.where("age").is(new Age(10))); + ProjectionOperation projectStage = project("age", "name"); + + DBObject agg = newAggregation(matchStage, projectStage).toDbObject("test", context); + + DBObject age = getValue((DBObject) getValue(getPipelineElementFromAggregationAt(agg, 0), "$match"), "age"); + assertThat(age, is((DBObject) new BasicDBObject("v", 10))); + } + + /** + * @see DATAMONGO-912 + */ + @Test + public void shouldUseCustomConversionIfPresentAndConversionIsRequiredInLaterStage() { + + CustomConversions customConversions = customAgeConversions(); + converter.setCustomConversions(customConversions); + customConversions.registerConvertersIn((GenericConversionService) converter.getConversionService()); + + AggregationOperationContext context = getContext(FooPerson.class); + + MatchOperation matchStage = match(Criteria.where("age").is(new Age(10))); + ProjectionOperation projectStage = project("age", "name"); + + DBObject agg = newAggregation(projectStage, matchStage).toDbObject("test", context); + + DBObject age = getValue((DBObject) getValue(getPipelineElementFromAggregationAt(agg, 1), "$match"), "age"); + assertThat(age, is((DBObject) new BasicDBObject("v", 10))); + } + + @Document(collection = "person") + public static class FooPerson { + + final ObjectId id; + final String name; + final Age age; + + @PersistenceConstructor + FooPerson(ObjectId id, String name, Age age) { + this.id = id; + this.name = name; + this.age = age; + } + } + + public static class Age { + + final int value; + + Age(int value) { + this.value = value; + } + } + + public CustomConversions customAgeConversions() { + return new CustomConversions(Arrays.> asList(ageWriteConverter(), ageReadConverter())); + } + + Converter ageWriteConverter() { + return new Converter() { + @Override + public DBObject convert(Age age) { + return new BasicDBObject("v", age.value); + } + }; + } + + Converter ageReadConverter() { + return new Converter() { + @Override + public Age convert(DBObject dbObject) { + return new Age(((Integer) dbObject.get("v"))); + } + }; + } + + @SuppressWarnings("unchecked") + static DBObject getPipelineElementFromAggregationAt(DBObject agg, int index) { + return ((List) agg.get("pipeline")).get(index); + } + + @SuppressWarnings("unchecked") + static T getValue(DBObject o, String key) { + return (T) o.get(key); + } + private TypeBasedAggregationOperationContext getContext(Class type) { return new TypeBasedAggregationOperationContext(type, context, mapper); }