@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2013 the original author or authors .
* Copyright 2013 - 2014 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -17,25 +17,40 @@ package org.springframework.data.mongodb.core.aggregation;
@@ -17,25 +17,40 @@ package org.springframework.data.mongodb.core.aggregation;
import static org.hamcrest.CoreMatchers.* ;
import static org.junit.Assert.* ;
import static org.springframework.data.mongodb.core.aggregation.Aggregation.* ;
import java.util.Arrays ;
import java.util.List ;
import org.bson.types.ObjectId ;
import org.junit.Before ;
import org.junit.Test ;
import org.junit.runner.RunWith ;
import org.mockito.Mock ;
import org.mockito.runners.MockitoJUnitRunner ;
import org.springframework.core.convert.converter.Converter ;
import org.springframework.core.convert.support.GenericConversionService ;
import org.springframework.data.annotation.Id ;
import org.springframework.data.annotation.PersistenceConstructor ;
import org.springframework.data.mapping.model.MappingException ;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField ;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference ;
import org.springframework.data.mongodb.core.convert.CustomConversions ;
import org.springframework.data.mongodb.core.convert.DbRefResolver ;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter ;
import org.springframework.data.mongodb.core.convert.QueryMapper ;
import org.springframework.data.mongodb.core.mapping.Document ;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext ;
import org.springframework.data.mongodb.core.query.Criteria ;
import com.mongodb.BasicDBObject ;
import com.mongodb.DBObject ;
/ * *
* Unit tests for { @link TypeBasedAggregationOperationContext } .
*
* @author Oliver Gierke
* @author Thomas Darimont
* /
@RunWith ( MockitoJUnitRunner . class )
public class TypeBasedAggregationOperationContextUnitTests {
@ -89,6 +104,104 @@ public class TypeBasedAggregationOperationContextUnitTests {
@@ -89,6 +104,104 @@ public class TypeBasedAggregationOperationContextUnitTests {
assertThat ( context . getReference ( "id" ) , is ( new FieldReference ( new ExposedField ( Fields . field ( "id" , "_id" ) , true ) ) ) ) ;
}
/ * *
* @see DATAMONGO - 912
* /
@Test
public void shouldUseCustomConversionIfPresentAndConversionIsRequiredInFirstStage ( ) {
CustomConversions customConversions = customAgeConversions ( ) ;
converter . setCustomConversions ( customConversions ) ;
customConversions . registerConvertersIn ( ( GenericConversionService ) converter . getConversionService ( ) ) ;
AggregationOperationContext context = getContext ( FooPerson . class ) ;
MatchOperation matchStage = match ( Criteria . where ( "age" ) . is ( new Age ( 10 ) ) ) ;
ProjectionOperation projectStage = project ( "age" , "name" ) ;
DBObject agg = newAggregation ( matchStage , projectStage ) . toDbObject ( "test" , context ) ;
DBObject age = getValue ( ( DBObject ) getValue ( getPipelineElementFromAggregationAt ( agg , 0 ) , "$match" ) , "age" ) ;
assertThat ( age , is ( ( DBObject ) new BasicDBObject ( "v" , 10 ) ) ) ;
}
/ * *
* @see DATAMONGO - 912
* /
@Test
public void shouldUseCustomConversionIfPresentAndConversionIsRequiredInLaterStage ( ) {
CustomConversions customConversions = customAgeConversions ( ) ;
converter . setCustomConversions ( customConversions ) ;
customConversions . registerConvertersIn ( ( GenericConversionService ) converter . getConversionService ( ) ) ;
AggregationOperationContext context = getContext ( FooPerson . class ) ;
MatchOperation matchStage = match ( Criteria . where ( "age" ) . is ( new Age ( 10 ) ) ) ;
ProjectionOperation projectStage = project ( "age" , "name" ) ;
DBObject agg = newAggregation ( projectStage , matchStage ) . toDbObject ( "test" , context ) ;
DBObject age = getValue ( ( DBObject ) getValue ( getPipelineElementFromAggregationAt ( agg , 1 ) , "$match" ) , "age" ) ;
assertThat ( age , is ( ( DBObject ) new BasicDBObject ( "v" , 10 ) ) ) ;
}
@Document ( collection = "person" )
public static class FooPerson {
final ObjectId id ;
final String name ;
final Age age ;
@PersistenceConstructor
FooPerson ( ObjectId id , String name , Age age ) {
this . id = id ;
this . name = name ;
this . age = age ;
}
}
public static class Age {
final int value ;
Age ( int value ) {
this . value = value ;
}
}
public CustomConversions customAgeConversions ( ) {
return new CustomConversions ( Arrays . < Converter < ? , ? > > asList ( ageWriteConverter ( ) , ageReadConverter ( ) ) ) ;
}
Converter < Age , DBObject > ageWriteConverter ( ) {
return new Converter < Age , DBObject > ( ) {
@Override
public DBObject convert ( Age age ) {
return new BasicDBObject ( "v" , age . value ) ;
}
} ;
}
Converter < DBObject , Age > ageReadConverter ( ) {
return new Converter < DBObject , Age > ( ) {
@Override
public Age convert ( DBObject dbObject ) {
return new Age ( ( ( Integer ) dbObject . get ( "v" ) ) ) ;
}
} ;
}
@SuppressWarnings ( "unchecked" )
static DBObject getPipelineElementFromAggregationAt ( DBObject agg , int index ) {
return ( ( List < DBObject > ) agg . get ( "pipeline" ) ) . get ( index ) ;
}
@SuppressWarnings ( "unchecked" )
static < T > T getValue ( DBObject o , String key ) {
return ( T ) o . get ( key ) ;
}
private TypeBasedAggregationOperationContext getContext ( Class < ? > type ) {
return new TypeBasedAggregationOperationContext ( type , context , mapper ) ;
}