diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 59963bc80..9a7db57a3 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -2550,8 +2550,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { /** * {@link BatchAggregationLoader} is a little helper that can process cursor results returned by an aggregation * command execution. On presence of a {@literal nextBatch} indicated by presence of an {@code id} field in the - * {@code cursor} another {@code getMore} command gets executed reading the next batch of documents until everything - * has been loaded. + * {@code cursor} another {@code getMore} command gets executed reading the next batch of documents until all results + * are loaded. * * @author Christoph Strobl * @since 1.10 @@ -2561,6 +2561,10 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { private static final String CURSOR_FIELD = "cursor"; private static final String RESULT_FIELD = "result"; private static final String BATCH_SIZE_FIELD = "batchSize"; + private static final String FIRST_BATCH = "firstBatch"; + private static final String NEXT_BATCH = "nextBatch"; + private static final String SERVER_USED = "serverUsed"; + private static final String OK = "ok"; private final MongoTemplate template; private final ReadPreference readPreference; @@ -2573,135 +2577,118 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { this.batchSize = batchSize; } + /** + * Run aggregation command and fetch all results. + */ DBObject aggregate(String collectionName, Aggregation aggregation, AggregationOperationContext context) { - DBObject command = AggregationCommandPreparer.INSTANCE.prepareAggregationCommand(collectionName, aggregation, + DBObject command = prepareAggregationCommand(collectionName, aggregation, context, batchSize); if (LOGGER.isDebugEnabled()) { LOGGER.debug("Executing aggregation: {}", serializeToJsonSafely(command)); } - List results = aggregateBatched(collectionName, batchSize, command); - return mergeArregationCommandResults(results); + return mergeAggregationResults(aggregateBatched(command, collectionName, batchSize)); } - private DBObject mergeArregationCommandResults(List results) { - - DBObject commandResult = new BasicDBObject(); - if (results.size() == 1) { - commandResult = results.iterator().next(); - } else { - - List allResults = new ArrayList(); - - for (DBObject result : results) { - Collection foo = (Collection) result.get(RESULT_FIELD); - if (!CollectionUtils.isEmpty(foo)) { - allResults.addAll(foo); - } - } + /** + * Pre process the aggregation command sent to the server by adding {@code cursor} options to match execution on + * different server versions. + */ + private static DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation, + AggregationOperationContext context, int batchSize) { - // take general info from first batch - commandResult.put("serverUsed", results.iterator().next().get("serverUsed")); - commandResult.put("ok", results.iterator().next().get("ok")); + AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context; + DBObject command = aggregation.toDbObject(collectionName, rootContext); - // and append the merged results - commandResult.put(RESULT_FIELD, allResults); + if (!aggregation.getOptions().isExplain()) { + command.put(CURSOR_FIELD, new BasicDBObject(BATCH_SIZE_FIELD, batchSize)); } - return commandResult; + + return command; } - private List aggregateBatched(String collectionName, int batchSize, DBObject command) { + private List aggregateBatched(DBObject command, String collectionName, int batchSize) { List results = new ArrayList(); - CommandResult tmp = template.executeCommand(command, readPreference); - results.add(AggregationResultPostProcessor.INSTANCE.process(command, tmp)); + CommandResult commandResult = template.executeCommand(command, readPreference); + results.add(postProcessResult(command, commandResult)); - while (hasNext(tmp)) { + while (hasNext(commandResult)) { - DBObject getMore = new BasicDBObject("getMore", getNextBatchId(tmp)) // + DBObject getMore = new BasicDBObject("getMore", getNextBatchId(commandResult)) // .append("collection", collectionName) // - .append(BATCH_SIZE_FIELD, batchSize); // + .append(BATCH_SIZE_FIELD, batchSize); - tmp = template.executeCommand(getMore, this.readPreference); - results.add(AggregationResultPostProcessor.INSTANCE.process(command, tmp)); + commandResult = template.executeCommand(getMore, this.readPreference); + results.add(postProcessResult(command, commandResult)); } return results; } - private boolean hasNext(DBObject commandResult) { + private static DBObject postProcessResult(DBObject command, CommandResult commandResult) { + + handleCommandError(commandResult, command); if (!commandResult.containsField(CURSOR_FIELD)) { - return false; + return commandResult; } - Object next = getNextBatchId(commandResult); - return (next == null || ((Number) next).longValue() == 0L) ? false : true; - } + DBObject resultObject = new BasicDBObject(SERVER_USED, commandResult.get(SERVER_USED)); + resultObject.put(OK, commandResult.get(OK)); - private Object getNextBatchId(DBObject commandResult) { - return ((DBObject) commandResult.get(CURSOR_FIELD)).get("id"); + DBObject cursor = (DBObject) commandResult.get(CURSOR_FIELD); + if (cursor.containsField(FIRST_BATCH)) { + resultObject.put(RESULT_FIELD, cursor.get(FIRST_BATCH)); + } else { + resultObject.put(RESULT_FIELD, cursor.get(NEXT_BATCH)); + } + + return resultObject; } - /** - * Helper to pre process the aggregation command sent to the server by adding {@code cursor} options to match - * execution on different server versions. - * - * @author Christoph Strobl - * @since 1.10 - */ - private static enum AggregationCommandPreparer { + private static DBObject mergeAggregationResults(List batchResults) { - INSTANCE; + if (batchResults.size() == 1) { + return batchResults.iterator().next(); + } - DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation, - AggregationOperationContext context, int batchSize) { + DBObject commandResult = new BasicDBObject(3); + List allResults = new ArrayList(); - AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context; - DBObject command = aggregation.toDbObject(collectionName, rootContext); + for (DBObject batchResult : batchResults) { - if (!aggregation.getOptions().isExplain()) { - command.put(CURSOR_FIELD, new BasicDBObject(BATCH_SIZE_FIELD, batchSize)); + Collection documents = (Collection) batchResult.get(RESULT_FIELD); + if (!CollectionUtils.isEmpty(documents)) { + allResults.addAll(documents); } - - return command; } - } - /** - * Helper to post process aggregation command result by copying over required attributes. - * - * @author Christoph Strobl - * @since 1.10 - */ - private static enum AggregationResultPostProcessor { + // take general info from first batch + commandResult.put(SERVER_USED, batchResults.iterator().next().get(SERVER_USED)); + commandResult.put(OK, batchResults.iterator().next().get(OK)); - INSTANCE; + // and append the merged batchResults + commandResult.put(RESULT_FIELD, allResults); - DBObject process(DBObject command, CommandResult commandResult) { + return commandResult; + } - handleCommandError(commandResult, command); + private static boolean hasNext(DBObject commandResult) { - if (!commandResult.containsField(CURSOR_FIELD)) { - return commandResult; - } - - DBObject resultObject = new BasicDBObject("serverUsed", commandResult.get("serverUsed")); - resultObject.put("ok", commandResult.get("ok")); + if (!commandResult.containsField(CURSOR_FIELD)) { + return false; + } - DBObject cursor = (DBObject) commandResult.get(CURSOR_FIELD); - if (cursor.containsField("firstBatch")) { - resultObject.put(RESULT_FIELD, cursor.get("firstBatch")); - } else { - resultObject.put(RESULT_FIELD, cursor.get("nextBatch")); - } + Object next = getNextBatchId(commandResult); + return next != null && ((Number) next).longValue() != 0L; + } - return resultObject; - } + private static Object getNextBatchId(DBObject commandResult) { + return ((DBObject) commandResult.get(CURSOR_FIELD)).get("id"); } } - } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index 922f192a2..03600b9bd 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -599,8 +599,8 @@ public class Aggregation { /** * Get {@link AggregationOptions} to apply. * - * @return never {@literal null} - * @since 1.10 + * @return never {@literal null}. + * @since 1.10.10 */ public AggregationOptions getOptions() { return options; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java index 5941b465c..34838b45a 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java @@ -38,7 +38,10 @@ import com.mongodb.DBObject; import com.mongodb.ReadPreference; /** + * Unit tests for {@link BatchAggregationLoader}. + * * @author Christoph Strobl + * @author Mark Paluch */ @RunWith(MockitoJUnitRunner.class) public class BatchAggregationLoaderUnitTests { @@ -52,11 +55,13 @@ public class BatchAggregationLoaderUnitTests { BatchAggregationLoader loader; - DBObject cursorWithoutMore = new BasicDBObject("firstBatch", singletonList(new BasicDBObject("name", "luke"))); + BasicDBObject luke = new BasicDBObject("name", "luke"); + BasicDBObject han = new BasicDBObject("name", "han"); + DBObject cursorWithoutMore = new BasicDBObject("firstBatch", singletonList(luke)); DBObject cursorWithMore = new BasicDBObject("id", 123).append("firstBatch", - singletonList(new BasicDBObject("name", "luke"))); + singletonList(luke)); DBObject cursorWithNoMore = new BasicDBObject("id", 0).append("nextBatch", - singletonList(new BasicDBObject("name", "han"))); + singletonList(han)); @Before public void setUp() { @@ -64,7 +69,20 @@ public class BatchAggregationLoaderUnitTests { } @Test // DATAMONGO-1824 - public void shouldLoadJustOneBatchWhenAlreayDoneWithFirst() { + public void shouldLoadWithoutCursor() { + + when(template.executeCommand(any(DBObject.class), any(ReadPreference.class))).thenReturn(aggregationResult); + when(aggregationResult.get("result")).thenReturn(singletonList(luke)); + + DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT); + assertThat((List) result.get("result"), IsCollectionContaining. hasItem(luke)); + + verify(template).executeCommand(any(DBObject.class), any(ReadPreference.class)); + verifyNoMoreInteractions(template); + } + + @Test // DATAMONGO-1824 + public void shouldLoadJustOneBatchWhenAlreadyDoneWithFirst() { when(template.executeCommand(any(DBObject.class), any(ReadPreference.class))).thenReturn(aggregationResult); when(aggregationResult.containsField("cursor")).thenReturn(true); @@ -72,7 +90,7 @@ public class BatchAggregationLoaderUnitTests { DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT); assertThat((List) result.get("result"), - IsCollectionContaining. hasItem(new BasicDBObject("name", "luke"))); + IsCollectionContaining. hasItem(luke)); verify(template).executeCommand(any(DBObject.class), any(ReadPreference.class)); verifyNoMoreInteractions(template); @@ -90,7 +108,7 @@ public class BatchAggregationLoaderUnitTests { DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT); assertThat((List) result.get("result"), - IsCollectionContaining. hasItems(new BasicDBObject("name", "luke"), new BasicDBObject("name", "han"))); + IsCollectionContaining. hasItems(luke, han)); verify(template, times(2)).executeCommand(any(DBObject.class), any(ReadPreference.class)); verifyNoMoreInteractions(template);