From df1775572add1a156b0791a62a46415bafe83bff Mon Sep 17 00:00:00 2001 From: Thomas Darimont Date: Thu, 17 Apr 2014 11:27:46 +0200 Subject: [PATCH] DATAMONGO-912 - Consider custom conversions in all stages of an aggregation pipeline. We now consider custom mongo conversions in all stages of an aggregation pipeline. Previously we did this only for the first stage and returned object basically unmapped in later stages. We now pass the root AggregationOperationContext on to nested ExposedFieldsAggregationOperationContexts so that those can delegate any mongo Mapping to the root context. Original pull request: #170. --- .../mongodb/core/aggregation/Aggregation.java | 4 +- ...osedFieldsAggregationOperationContext.java | 12 +- ...dAggregationOperationContextUnitTests.java | 115 +++++++++++++++++- 3 files changed, 125 insertions(+), 6 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index f3d1d03ac..dd81685fc 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -1,5 +1,5 @@ /* - * Copyright 2013 the original author or authors. + * Copyright 2013-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -248,7 +248,7 @@ public class Aggregation { if (operation instanceof FieldsExposingAggregationOperation) { FieldsExposingAggregationOperation exposedFieldsOperation = (FieldsExposingAggregationOperation) operation; - context = new ExposedFieldsAggregationOperationContext(exposedFieldsOperation.getFields()); + context = new ExposedFieldsAggregationOperationContext(exposedFieldsOperation.getFields(), rootContext); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java index ca7d8db5b..6ac32fa4a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java @@ -32,16 +32,22 @@ import com.mongodb.DBObject; class ExposedFieldsAggregationOperationContext implements AggregationOperationContext { private final ExposedFields exposedFields; + private final AggregationOperationContext rootContext; /** - * Creates a new {@link ExposedFieldsAggregationOperationContext} from the given {@link ExposedFields}. + * Creates a new {@link ExposedFieldsAggregationOperationContext} from the given {@link ExposedFields}. Uses the given + * {@link AggregationOperationContext} to perform a mapping to mongo types if necessary. * * @param exposedFields must not be {@literal null}. + * @param rootContext must not be {@literal null}. */ - public ExposedFieldsAggregationOperationContext(ExposedFields exposedFields) { + public ExposedFieldsAggregationOperationContext(ExposedFields exposedFields, AggregationOperationContext rootContext) { Assert.notNull(exposedFields, "ExposedFields must not be null!"); + Assert.notNull(rootContext, "RootContext must not be null!"); + this.exposedFields = exposedFields; + this.rootContext = rootContext; } /* @@ -50,7 +56,7 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo */ @Override public DBObject getMappedObject(DBObject dbObject) { - return dbObject; + return rootContext.getMappedObject(dbObject); } /* diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java index 6aae2af59..3ec62b847 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContextUnitTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2013 the original author or authors. + * Copyright 2013-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,25 +17,40 @@ package org.springframework.data.mongodb.core.aggregation; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.*; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; +import java.util.Arrays; +import java.util.List; + +import org.bson.types.ObjectId; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.core.convert.converter.Converter; +import org.springframework.core.convert.support.GenericConversionService; import org.springframework.data.annotation.Id; +import org.springframework.data.annotation.PersistenceConstructor; import org.springframework.data.mapping.model.MappingException; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; +import org.springframework.data.mongodb.core.convert.CustomConversions; import org.springframework.data.mongodb.core.convert.DbRefResolver; import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.mapping.Document; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.data.mongodb.core.query.Criteria; + +import com.mongodb.BasicDBObject; +import com.mongodb.DBObject; /** * Unit tests for {@link TypeBasedAggregationOperationContext}. * * @author Oliver Gierke + * @author Thomas Darimont */ @RunWith(MockitoJUnitRunner.class) public class TypeBasedAggregationOperationContextUnitTests { @@ -89,6 +104,104 @@ public class TypeBasedAggregationOperationContextUnitTests { assertThat(context.getReference("id"), is(new FieldReference(new ExposedField(Fields.field("id", "_id"), true)))); } + /** + * @see DATAMONGO-912 + */ + @Test + public void shouldUseCustomConversionIfPresentAndConversionIsRequiredInFirstStage() { + + CustomConversions customConversions = customAgeConversions(); + converter.setCustomConversions(customConversions); + customConversions.registerConvertersIn((GenericConversionService) converter.getConversionService()); + + AggregationOperationContext context = getContext(FooPerson.class); + + MatchOperation matchStage = match(Criteria.where("age").is(new Age(10))); + ProjectionOperation projectStage = project("age", "name"); + + DBObject agg = newAggregation(matchStage, projectStage).toDbObject("test", context); + + DBObject age = getValue((DBObject) getValue(getPipelineElementFromAggregationAt(agg, 0), "$match"), "age"); + assertThat(age, is((DBObject) new BasicDBObject("v", 10))); + } + + /** + * @see DATAMONGO-912 + */ + @Test + public void shouldUseCustomConversionIfPresentAndConversionIsRequiredInLaterStage() { + + CustomConversions customConversions = customAgeConversions(); + converter.setCustomConversions(customConversions); + customConversions.registerConvertersIn((GenericConversionService) converter.getConversionService()); + + AggregationOperationContext context = getContext(FooPerson.class); + + MatchOperation matchStage = match(Criteria.where("age").is(new Age(10))); + ProjectionOperation projectStage = project("age", "name"); + + DBObject agg = newAggregation(projectStage, matchStage).toDbObject("test", context); + + DBObject age = getValue((DBObject) getValue(getPipelineElementFromAggregationAt(agg, 1), "$match"), "age"); + assertThat(age, is((DBObject) new BasicDBObject("v", 10))); + } + + @Document(collection = "person") + public static class FooPerson { + + final ObjectId id; + final String name; + final Age age; + + @PersistenceConstructor + FooPerson(ObjectId id, String name, Age age) { + this.id = id; + this.name = name; + this.age = age; + } + } + + public static class Age { + + final int value; + + Age(int value) { + this.value = value; + } + } + + public CustomConversions customAgeConversions() { + return new CustomConversions(Arrays.> asList(ageWriteConverter(), ageReadConverter())); + } + + Converter ageWriteConverter() { + return new Converter() { + @Override + public DBObject convert(Age age) { + return new BasicDBObject("v", age.value); + } + }; + } + + Converter ageReadConverter() { + return new Converter() { + @Override + public Age convert(DBObject dbObject) { + return new Age(((Integer) dbObject.get("v"))); + } + }; + } + + @SuppressWarnings("unchecked") + static DBObject getPipelineElementFromAggregationAt(DBObject agg, int index) { + return ((List) agg.get("pipeline")).get(index); + } + + @SuppressWarnings("unchecked") + static T getValue(DBObject o, String key) { + return (T) o.get(key); + } + private TypeBasedAggregationOperationContext getContext(Class type) { return new TypeBasedAggregationOperationContext(type, context, mapper); }