@ -2550,8 +2550,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
@@ -2550,8 +2550,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
/ * *
* { @link BatchAggregationLoader } is a little helper that can process cursor results returned by an aggregation
* command execution . On presence of a { @literal nextBatch } indicated by presence of an { @code id } field in the
* { @code cursor } another { @code getMore } command gets executed reading the next batch of documents until everything
* has been loaded .
* { @code cursor } another { @code getMore } command gets executed reading the next batch of documents until all results
* are loaded .
*
* @author Christoph Strobl
* @since 1 . 10
@ -2561,6 +2561,10 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
@@ -2561,6 +2561,10 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
private static final String CURSOR_FIELD = "cursor" ;
private static final String RESULT_FIELD = "result" ;
private static final String BATCH_SIZE_FIELD = "batchSize" ;
private static final String FIRST_BATCH = "firstBatch" ;
private static final String NEXT_BATCH = "nextBatch" ;
private static final String SERVER_USED = "serverUsed" ;
private static final String OK = "ok" ;
private final MongoTemplate template ;
private final ReadPreference readPreference ;
@ -2573,135 +2577,118 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
@@ -2573,135 +2577,118 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
this . batchSize = batchSize ;
}
/ * *
* Run aggregation command and fetch all results .
* /
DBObject aggregate ( String collectionName , Aggregation aggregation , AggregationOperationContext context ) {
DBObject command = AggregationCommandPreparer . INSTANCE . prepareAggregationCommand ( collectionName , aggregation ,
DBObject command = prepareAggregationCommand ( collectionName , aggregation ,
context , batchSize ) ;
if ( LOGGER . isDebugEnabled ( ) ) {
LOGGER . debug ( "Executing aggregation: {}" , serializeToJsonSafely ( command ) ) ;
}
List < DBObject > results = aggregateBatched ( collectionName , batchSize , command ) ;
return mergeArregationCommandResults ( results ) ;
return mergeAggregationResults ( aggregateBatched ( command , collectionName , batchSize ) ) ;
}
private DBObject mergeArregationCommandResults ( List < DBObject > results ) {
DBObject commandResult = new BasicDBObject ( ) ;
if ( results . size ( ) = = 1 ) {
commandResult = results . iterator ( ) . next ( ) ;
} else {
List < Object > allResults = new ArrayList ( ) ;
for ( DBObject result : results ) {
Collection foo = ( Collection < ? > ) result . get ( RESULT_FIELD ) ;
if ( ! CollectionUtils . isEmpty ( foo ) ) {
allResults . addAll ( foo ) ;
}
}
/ * *
* Pre process the aggregation command sent to the server by adding { @code cursor } options to match execution on
* different server versions .
* /
private static DBObject prepareAggregationCommand ( String collectionName , Aggregation aggregation ,
AggregationOperationContext context , int batchSize ) {
// take general info from first batch
commandResult . put ( "serverUsed" , results . iterator ( ) . next ( ) . get ( "serverUsed" ) ) ;
commandResult . put ( "ok" , results . iterator ( ) . next ( ) . get ( "ok" ) ) ;
AggregationOperationContext rootContext = context = = null ? Aggregation . DEFAULT_CONTEXT : context ;
DBObject command = aggregation . toDbObject ( collectionName , rootContext ) ;
// and append the merged results
commandResult . put ( RESULT_FIELD , allResults ) ;
if ( ! aggregation . getOptions ( ) . isExplain ( ) ) {
command . put ( CURSOR_FIELD , new BasicDBObject ( BATCH_SIZE_FIELD , batchSize ) ) ;
}
return commandResult ;
return command ;
}
private List < DBObject > aggregateBatched ( String collectionName , int batchSize , DBObject command ) {
private List < DBObject > aggregateBatched ( DBObject command , String collectionName , int batchSize ) {
List < DBObject > results = new ArrayList < DBObject > ( ) ;
CommandResult tmp = template . executeCommand ( command , readPreference ) ;
results . add ( AggregationResultPostProcessor . INSTANCE . process ( command , tmp ) ) ;
CommandResult commandResult = template . executeCommand ( command , readPreference ) ;
results . add ( postProcessResult ( command , commandResult ) ) ;
while ( hasNext ( tmp ) ) {
while ( hasNext ( commandResult ) ) {
DBObject getMore = new BasicDBObject ( "getMore" , getNextBatchId ( tmp ) ) //
DBObject getMore = new BasicDBObject ( "getMore" , getNextBatchId ( commandResult ) ) //
. append ( "collection" , collectionName ) //
. append ( BATCH_SIZE_FIELD , batchSize ) ; //
. append ( BATCH_SIZE_FIELD , batchSize ) ;
tmp = template . executeCommand ( getMore , this . readPreference ) ;
results . add ( AggregationResultPostProcessor . INSTANCE . process ( command , tmp ) ) ;
commandResult = template . executeCommand ( getMore , this . readPreference ) ;
results . add ( postProcessResult ( command , commandResult ) ) ;
}
return results ;
}
private boolean hasNext ( DBObject commandResult ) {
private static DBObject postProcessResult ( DBObject command , CommandResult commandResult ) {
handleCommandError ( commandResult , command ) ;
if ( ! commandResult . containsField ( CURSOR_FIELD ) ) {
return false ;
return commandResult ;
}
Object next = getNextBatchId ( commandResult ) ;
return ( next = = null | | ( ( Number ) next ) . longValue ( ) = = 0L ) ? false : true ;
}
DBObject resultObject = new BasicDBObject ( SERVER_USED , commandResult . get ( SERVER_USED ) ) ;
resultObject . put ( OK , commandResult . get ( OK ) ) ;
private Object getNextBatchId ( DBObject commandResult ) {
return ( ( DBObject ) commandResult . get ( CURSOR_FIELD ) ) . get ( "id" ) ;
DBObject cursor = ( DBObject ) commandResult . get ( CURSOR_FIELD ) ;
if ( cursor . containsField ( FIRST_BATCH ) ) {
resultObject . put ( RESULT_FIELD , cursor . get ( FIRST_BATCH ) ) ;
} else {
resultObject . put ( RESULT_FIELD , cursor . get ( NEXT_BATCH ) ) ;
}
return resultObject ;
}
/ * *
* Helper to pre process the aggregation command sent to the server by adding { @code cursor } options to match
* execution on different server versions .
*
* @author Christoph Strobl
* @since 1 . 10
* /
private static enum AggregationCommandPreparer {
private static DBObject mergeAggregationResults ( List < DBObject > batchResults ) {
INSTANCE ;
if ( batchResults . size ( ) = = 1 ) {
return batchResults . iterator ( ) . next ( ) ;
}
DBObject prepareAggregationCommand ( String collectionName , Aggregation aggregation ,
AggregationOperationContext context , int batchSize ) {
DBObject commandResult = new BasicDBObject ( 3 ) ;
List < Object > allResults = new ArrayList < Object > ( ) ;
AggregationOperationContext rootContext = context = = null ? Aggregation . DEFAULT_CONTEXT : context ;
DBObject command = aggregation . toDbObject ( collectionName , rootContext ) ;
for ( DBObject batchResult : batchResults ) {
if ( ! aggregation . getOptions ( ) . isExplain ( ) ) {
command . put ( CURSOR_FIELD , new BasicDBObject ( BATCH_SIZE_FIELD , batchSize ) ) ;
Collection documents = ( Collection < ? > ) batchResult . get ( RESULT_FIELD ) ;
if ( ! CollectionUtils . isEmpty ( documents ) ) {
allResults . addAll ( documents ) ;
}
return command ;
}
}
/ * *
* Helper to post process aggregation command result by copying over required attributes .
*
* @author Christoph Strobl
* @since 1 . 10
* /
private static enum AggregationResultPostProcessor {
// take general info from first batch
commandResult . put ( SERVER_USED , batchResults . iterator ( ) . next ( ) . get ( SERVER_USED ) ) ;
commandResult . put ( OK , batchResults . iterator ( ) . next ( ) . get ( OK ) ) ;
INSTANCE ;
// and append the merged batchResults
commandResult . put ( RESULT_FIELD , allResults ) ;
DBObject process ( DBObject command , CommandResult commandResult ) {
return commandResult ;
}
handleCommandError ( commandResult , command ) ;
private static boolean hasNext ( DBObject commandResult ) {
if ( ! commandResult . containsField ( CURSOR_FIELD ) ) {
return commandResult ;
}
DBObject resultObject = new BasicDBObject ( "serverUsed" , commandResult . get ( "serverUsed" ) ) ;
resultObject . put ( "ok" , commandResult . get ( "ok" ) ) ;
if ( ! commandResult . containsField ( CURSOR_FIELD ) ) {
return false ;
}
DBObject cursor = ( DBObject ) commandResult . get ( CURSOR_FIELD ) ;
if ( cursor . containsField ( "firstBatch" ) ) {
resultObject . put ( RESULT_FIELD , cursor . get ( "firstBatch" ) ) ;
} else {
resultObject . put ( RESULT_FIELD , cursor . get ( "nextBatch" ) ) ;
}
Object next = getNextBatchId ( commandResult ) ;
return next ! = null & & ( ( Number ) next ) . longValue ( ) ! = 0L ;
}
return resultObject ;
}
private static Object getNextBatchId ( DBObject commandResult ) {
return ( ( DBObject ) commandResult . get ( CURSOR_FIELD ) ) . get ( "id" ) ;
}
}
}