diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 759b3ebf5..59963bc80 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -1566,15 +1566,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { Assert.notNull(aggregation, "Aggregation pipeline must not be null!"); Assert.notNull(outputType, "Output type must not be null!"); - AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context; - DBObject command = aggregation.toDbObject(collectionName, rootContext); - - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Executing aggregation: {}", serializeToJsonSafely(command)); - } - - CommandResult commandResult = executeCommand(command, this.readPreference); - handleCommandError(commandResult, command); + DBObject commandResult = new BatchAggregationLoader(this, readPreference, Integer.MAX_VALUE) + .aggregate(collectionName, aggregation, context); return new AggregationResults(returnPotentiallyMappedResults(outputType, commandResult, collectionName), commandResult); @@ -1587,7 +1580,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { * @param commandResult * @return */ - private List returnPotentiallyMappedResults(Class outputType, CommandResult commandResult, + private List returnPotentiallyMappedResults(Class outputType, DBObject commandResult, String collectionName) { @SuppressWarnings("unchecked") @@ -2094,7 +2087,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { * @param result must not be {@literal null}. * @param source must not be {@literal null}. */ - private void handleCommandError(CommandResult result, DBObject source) { + private static void handleCommandError(CommandResult result, DBObject source) { try { result.throwOnError(); @@ -2553,4 +2546,162 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { } } } + + /** + * {@link BatchAggregationLoader} is a little helper that can process cursor results returned by an aggregation + * command execution. On presence of a {@literal nextBatch} indicated by presence of an {@code id} field in the + * {@code cursor} another {@code getMore} command gets executed reading the next batch of documents until everything + * has been loaded. + * + * @author Christoph Strobl + * @since 1.10 + */ + static class BatchAggregationLoader { + + private static final String CURSOR_FIELD = "cursor"; + private static final String RESULT_FIELD = "result"; + private static final String BATCH_SIZE_FIELD = "batchSize"; + + private final MongoTemplate template; + private final ReadPreference readPreference; + private final int batchSize; + + BatchAggregationLoader(MongoTemplate template, ReadPreference readPreference, int batchSize) { + + this.template = template; + this.readPreference = readPreference; + this.batchSize = batchSize; + } + + DBObject aggregate(String collectionName, Aggregation aggregation, AggregationOperationContext context) { + + DBObject command = AggregationCommandPreparer.INSTANCE.prepareAggregationCommand(collectionName, aggregation, + context, batchSize); + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Executing aggregation: {}", serializeToJsonSafely(command)); + } + + List results = aggregateBatched(collectionName, batchSize, command); + return mergeArregationCommandResults(results); + } + + private DBObject mergeArregationCommandResults(List results) { + + DBObject commandResult = new BasicDBObject(); + if (results.size() == 1) { + commandResult = results.iterator().next(); + } else { + + List allResults = new ArrayList(); + + for (DBObject result : results) { + Collection foo = (Collection) result.get(RESULT_FIELD); + if (!CollectionUtils.isEmpty(foo)) { + allResults.addAll(foo); + } + } + + // take general info from first batch + commandResult.put("serverUsed", results.iterator().next().get("serverUsed")); + commandResult.put("ok", results.iterator().next().get("ok")); + + // and append the merged results + commandResult.put(RESULT_FIELD, allResults); + } + return commandResult; + } + + private List aggregateBatched(String collectionName, int batchSize, DBObject command) { + + List results = new ArrayList(); + + CommandResult tmp = template.executeCommand(command, readPreference); + results.add(AggregationResultPostProcessor.INSTANCE.process(command, tmp)); + + while (hasNext(tmp)) { + + DBObject getMore = new BasicDBObject("getMore", getNextBatchId(tmp)) // + .append("collection", collectionName) // + .append(BATCH_SIZE_FIELD, batchSize); // + + tmp = template.executeCommand(getMore, this.readPreference); + results.add(AggregationResultPostProcessor.INSTANCE.process(command, tmp)); + } + + return results; + } + + private boolean hasNext(DBObject commandResult) { + + if (!commandResult.containsField(CURSOR_FIELD)) { + return false; + } + + Object next = getNextBatchId(commandResult); + return (next == null || ((Number) next).longValue() == 0L) ? false : true; + } + + private Object getNextBatchId(DBObject commandResult) { + return ((DBObject) commandResult.get(CURSOR_FIELD)).get("id"); + } + + /** + * Helper to pre process the aggregation command sent to the server by adding {@code cursor} options to match + * execution on different server versions. + * + * @author Christoph Strobl + * @since 1.10 + */ + private static enum AggregationCommandPreparer { + + INSTANCE; + + DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation, + AggregationOperationContext context, int batchSize) { + + AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context; + DBObject command = aggregation.toDbObject(collectionName, rootContext); + + if (!aggregation.getOptions().isExplain()) { + command.put(CURSOR_FIELD, new BasicDBObject(BATCH_SIZE_FIELD, batchSize)); + } + + return command; + } + } + + /** + * Helper to post process aggregation command result by copying over required attributes. + * + * @author Christoph Strobl + * @since 1.10 + */ + private static enum AggregationResultPostProcessor { + + INSTANCE; + + DBObject process(DBObject command, CommandResult commandResult) { + + handleCommandError(commandResult, command); + + if (!commandResult.containsField(CURSOR_FIELD)) { + return commandResult; + } + + DBObject resultObject = new BasicDBObject("serverUsed", commandResult.get("serverUsed")); + resultObject.put("ok", commandResult.get("ok")); + + DBObject cursor = (DBObject) commandResult.get(CURSOR_FIELD); + if (cursor.containsField("firstBatch")) { + resultObject.put(RESULT_FIELD, cursor.get("firstBatch")); + } else { + resultObject.put(RESULT_FIELD, cursor.get("nextBatch")); + } + + return resultObject; + } + } + } + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index 4798d1e1f..922f192a2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -596,6 +596,16 @@ public class Aggregation { return SerializationUtils.serializeToJsonSafely(toDbObject("__collection__", DEFAULT_CONTEXT)); } + /** + * Get {@link AggregationOptions} to apply. + * + * @return never {@literal null} + * @since 1.10 + */ + public AggregationOptions getOptions() { + return options; + } + /** * Describes the system variables available in MongoDB aggregation framework pipeline expressions. * diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java new file mode 100644 index 000000000..5941b465c --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java @@ -0,0 +1,98 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core; + +import static java.util.Collections.*; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; + +import java.util.List; + +import org.hamcrest.core.IsCollectionContaining; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.data.mongodb.core.MongoTemplate.BatchAggregationLoader; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.TypedAggregation; + +import com.mongodb.BasicDBObject; +import com.mongodb.CommandResult; +import com.mongodb.DBObject; +import com.mongodb.ReadPreference; + +/** + * @author Christoph Strobl + */ +@RunWith(MockitoJUnitRunner.class) +public class BatchAggregationLoaderUnitTests { + + static final TypedAggregation AGGREGATION = newAggregation(Person.class, + project().and("firstName").as("name")); + + @Mock MongoTemplate template; + @Mock CommandResult aggregationResult; + @Mock CommandResult getMoreResult; + + BatchAggregationLoader loader; + + DBObject cursorWithoutMore = new BasicDBObject("firstBatch", singletonList(new BasicDBObject("name", "luke"))); + DBObject cursorWithMore = new BasicDBObject("id", 123).append("firstBatch", + singletonList(new BasicDBObject("name", "luke"))); + DBObject cursorWithNoMore = new BasicDBObject("id", 0).append("nextBatch", + singletonList(new BasicDBObject("name", "han"))); + + @Before + public void setUp() { + loader = new BatchAggregationLoader(template, ReadPreference.primary(), 10); + } + + @Test // DATAMONGO-1824 + public void shouldLoadJustOneBatchWhenAlreayDoneWithFirst() { + + when(template.executeCommand(any(DBObject.class), any(ReadPreference.class))).thenReturn(aggregationResult); + when(aggregationResult.containsField("cursor")).thenReturn(true); + when(aggregationResult.get("cursor")).thenReturn(cursorWithoutMore); + + DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT); + assertThat((List) result.get("result"), + IsCollectionContaining. hasItem(new BasicDBObject("name", "luke"))); + + verify(template).executeCommand(any(DBObject.class), any(ReadPreference.class)); + verifyNoMoreInteractions(template); + } + + @Test // DATAMONGO-1824 + public void shouldBatchLoadWhenRequired() { + + when(template.executeCommand(any(DBObject.class), any(ReadPreference.class))).thenReturn(aggregationResult) + .thenReturn(getMoreResult); + when(aggregationResult.containsField("cursor")).thenReturn(true); + when(aggregationResult.get("cursor")).thenReturn(cursorWithMore); + when(getMoreResult.containsField("cursor")).thenReturn(true); + when(getMoreResult.get("cursor")).thenReturn(cursorWithNoMore); + + DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT); + assertThat((List) result.get("result"), + IsCollectionContaining. hasItems(new BasicDBObject("name", "luke"), new BasicDBObject("name", "han"))); + + verify(template, times(2)).executeCommand(any(DBObject.class), any(ReadPreference.class)); + verifyNoMoreInteractions(template); + } +}