@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2011 - 2012 the original author or authors .
* Copyright 2013 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -15,9 +15,8 @@
@@ -15,9 +15,8 @@
* /
package org.springframework.data.mongodb.core.aggregation ;
import static org.hamcrest.CoreMatchers.is ;
import static org.hamcrest.CoreMatchers.notNullValue ;
import static org.junit.Assert.assertThat ;
import static org.hamcrest.CoreMatchers.* ;
import static org.junit.Assert.* ;
import java.util.ArrayList ;
import java.util.List ;
@ -41,6 +40,7 @@ import com.mongodb.DBObject;
@@ -41,6 +40,7 @@ import com.mongodb.DBObject;
/ * *
* Tests for { @link MongoTemplate # aggregate ( String , AggregationPipeline , Class ) } .
*
* @see DATAMONGO - 586
* @author Tobias Trelle
* /
@RunWith ( SpringJUnit4ClassRunner . class )
@ -79,35 +79,26 @@ public class AggregationTests {
@@ -79,35 +79,26 @@ public class AggregationTests {
@Test ( expected = IllegalArgumentException . class )
public void shouldDetectIllegalJsonInOperation ( ) {
// given
AggregationPipeline pipeline = new AggregationPipeline ( ) . project ( "{ foo bar" ) ;
// when
AggregationPipeline pipeline = new AggregationPipeline ( ) . project ( "{ foo bar" ) ;
mongoTemplate . aggregate ( INPUT_COLLECTION , pipeline , TagCount . class ) ;
// then: throw expected exception
}
@Test
public void shouldAggregate ( ) {
// given
createDocuments ( ) ;
AggregationPipeline pipeline = new AggregationPipeline ( )
. project ( "{_id:0,tags:1}}" )
. unwind ( "tags" )
. group ( "{_id:\"$tags\", n:{$sum:1}}" )
. project ( "{tag: \"$_id\", n:1, _id:0}" )
. sort ( new Sort ( new Sort . Order ( Direction . DESC , "n" ) ) ) ;
// when
AggregationResults < TagCount > results = mongoTemplate . aggregate ( INPUT_COLLECTION , pipeline , TagCount . class ) ;
// then
assertThat ( results , notNullValue ( ) ) ;
AggregationPipeline pipeline = new AggregationPipeline ( ) . project ( "{_id:0,tags:1}}" ) . unwind ( "tags" )
. group ( "{_id:\"$tags\", n:{$sum:1}}" ) . project ( "{tag: \"$_id\", n:1, _id:0}" )
. sort ( new Sort ( new Sort . Order ( Direction . DESC , "n" ) ) ) ;
AggregationResults < TagCount > results = mongoTemplate . aggregate ( INPUT_COLLECTION , pipeline , TagCount . class ) ;
assertThat ( results , is ( notNullValue ( ) ) ) ;
assertThat ( results . getServerUsed ( ) , is ( "/127.0.0.1:27017" ) ) ;
List < TagCount > tagCount = results . getAggregationResult ( ) ;
assertThat ( tagCount , notNullValue ( ) ) ;
assertThat ( tagCount , is ( notNullValue ( ) ) ) ;
assertThat ( tagCount . size ( ) , is ( 3 ) ) ;
assertTagCount ( "spring" , 3 , tagCount . get ( 0 ) ) ;
assertTagCount ( "mongodb" , 2 , tagCount . get ( 1 ) ) ;
@ -116,87 +107,73 @@ public class AggregationTests {
@@ -116,87 +107,73 @@ public class AggregationTests {
@Test ( expected = InvalidDataAccessApiUsageException . class )
public void shouldDetectIllegalAggregationOperation ( ) {
// given
createDocuments ( ) ;
AggregationPipeline pipeline = new AggregationPipeline ( ) . project ( "{$foobar:{_id:0,tags:1}}" ) ;
// when
mongoTemplate . aggregate ( INPUT_COLLECTION , pipeline , TagCount . class ) ;
// then: throw expected exception
}
@Test
public void shouldAggregateEmptyCollection ( ) {
// given
AggregationPipeline pipeline = new AggregationPipeline ( )
. project ( "{_id:0,tags:1}}" )
. unwind ( "$tags" )
. group ( "{_id:\"$tags\", n:{$sum:1}}" )
. project ( "{tag: \"$_id\", n:1, _id:0}" )
. sort ( "{n:-1}" ) ;
// when
AggregationResults < TagCount > results = mongoTemplate . aggregate ( INPUT_COLLECTION , pipeline , TagCount . class ) ;
// then
assertThat ( results , notNullValue ( ) ) ;
AggregationPipeline pipeline = new AggregationPipeline ( ) . project ( "{_id:0,tags:1}}" ) . unwind ( "$tags" )
. group ( "{_id:\"$tags\", n:{$sum:1}}" ) . project ( "{tag: \"$_id\", n:1, _id:0}" ) . sort ( "{n:-1}" ) ;
AggregationResults < TagCount > results = mongoTemplate . aggregate ( INPUT_COLLECTION , pipeline , TagCount . class ) ;
assertThat ( results , is ( notNullValue ( ) ) ) ;
assertThat ( results . getServerUsed ( ) , is ( "/127.0.0.1:27017" ) ) ;
List < TagCount > tagCount = results . getAggregationResult ( ) ;
assertThat ( tagCount , notNullValue ( ) ) ;
assertThat ( tagCount , is ( notNullValue ( ) ) ) ;
assertThat ( tagCount . size ( ) , is ( 0 ) ) ;
}
@Test
public void shouldDetectResultMismatch ( ) {
// given
createDocuments ( ) ;
AggregationPipeline pipeline = new AggregationPipeline ( )
. project ( "{_id:0,tags:1}}" )
. unwind ( "$tags" )
. group ( "{_id:\"$tags\", count:{$sum:1}}" )
. limit ( 2 ) ;
// when
AggregationResults < TagCount > results = mongoTemplate . aggregate ( INPUT_COLLECTION , pipeline , TagCount . class ) ;
// then
assertThat ( results , notNullValue ( ) ) ;
createDocuments ( ) ;
AggregationPipeline pipeline = new AggregationPipeline ( ) . project ( "{_id:0,tags:1}}" ) . unwind ( "$tags" )
. group ( "{_id:\"$tags\", count:{$sum:1}}" ) . limit ( 2 ) ;
AggregationResults < TagCount > results = mongoTemplate . aggregate ( INPUT_COLLECTION , pipeline , TagCount . class ) ;
assertThat ( results , is ( notNullValue ( ) ) ) ;
assertThat ( results . getServerUsed ( ) , is ( "/127.0.0.1:27017" ) ) ;
List < TagCount > tagCount = results . getAggregationResult ( ) ;
assertThat ( tagCount , notNullValue ( ) ) ;
assertThat ( tagCount , is ( notNullValue ( ) ) ) ;
assertThat ( tagCount . size ( ) , is ( 2 ) ) ;
assertTagCount ( null , 0 , tagCount . get ( 0 ) ) ;
assertTagCount ( null , 0 , tagCount . get ( 1 ) ) ;
}
protected void cleanDb ( ) {
mongoTemplate . dropCollection ( INPUT_COLLECTION ) ;
}
private void createDocuments ( ) {
DBCollection coll = mongoTemplate . getCollection ( INPUT_COLLECTION ) ;
coll . insert ( createDocument ( "Doc1" , "spring" , "mongodb" , "nosql" ) ) ;
coll . insert ( createDocument ( "Doc2" , "spring" , "mongodb" ) ) ;
coll . insert ( createDocument ( "Doc3" , "spring" ) ) ;
}
private DBObject createDocument ( String title , String . . . tags ) {
private static DBObject createDocument ( String title , String . . . tags ) {
DBObject doc = new BasicDBObject ( "title" , title ) ;
List < String > tagList = new ArrayList < String > ( ) ;
for ( String tag : tags ) {
tagList . add ( tag ) ;
}
doc . put ( "tags" , tagList ) ;
doc . put ( "tags" , tagList ) ;
return doc ;
}
private void assertTagCount ( String tag , int n , TagCount tagCount ) {
private static void assertTagCount ( String tag , int n , TagCount tagCount ) {
assertThat ( tagCount . getTag ( ) , is ( tag ) ) ;
assertThat ( tagCount . getN ( ) , is ( n ) ) ;
}
}