
DATAMONGO-1824 - Polishing.

Move methods from AggregationCommandPreparer and AggregationResultPostProcessor into BatchAggregationLoader. Extract field names to constants. Apply tiny variable renames. Add unit test for an aggregation response that does not use a cursor.

Original pull request: #521.
pull/531/head
Mark Paluch 8 years ago
parent commit 3f009053fe
Changed files:
1. spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java (155 changed lines)
2. spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java (4 changed lines)
3. spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java (30 changed lines)

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

@@ -2550,8 +2550,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
 	/**
 	 * {@link BatchAggregationLoader} is a little helper that can process cursor results returned by an aggregation
 	 * command execution. On presence of a {@literal nextBatch} indicated by presence of an {@code id} field in the
-	 * {@code cursor} another {@code getMore} command gets executed reading the next batch of documents until everything
-	 * has been loaded.
+	 * {@code cursor} another {@code getMore} command gets executed reading the next batch of documents until all results
+	 * are loaded.
 	 *
 	 * @author Christoph Strobl
 	 * @since 1.10
@@ -2561,6 +2561,10 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
 		private static final String CURSOR_FIELD = "cursor";
 		private static final String RESULT_FIELD = "result";
 		private static final String BATCH_SIZE_FIELD = "batchSize";
+		private static final String FIRST_BATCH = "firstBatch";
+		private static final String NEXT_BATCH = "nextBatch";
+		private static final String SERVER_USED = "serverUsed";
+		private static final String OK = "ok";

 		private final MongoTemplate template;
 		private final ReadPreference readPreference;
@@ -2573,135 +2577,118 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
 			this.batchSize = batchSize;
 		}

 		/**
 		 * Run aggregation command and fetch all results.
 		 */
 		DBObject aggregate(String collectionName, Aggregation aggregation, AggregationOperationContext context) {

-			DBObject command = AggregationCommandPreparer.INSTANCE.prepareAggregationCommand(collectionName, aggregation,
+			DBObject command = prepareAggregationCommand(collectionName, aggregation,
 					context, batchSize);

 			if (LOGGER.isDebugEnabled()) {
 				LOGGER.debug("Executing aggregation: {}", serializeToJsonSafely(command));
 			}

-			List<DBObject> results = aggregateBatched(collectionName, batchSize, command);
-			return mergeArregationCommandResults(results);
+			return mergeAggregationResults(aggregateBatched(command, collectionName, batchSize));
 		}

-		private DBObject mergeArregationCommandResults(List<DBObject> results) {
-
-			DBObject commandResult = new BasicDBObject();
-
-			if (results.size() == 1) {
-				commandResult = results.iterator().next();
-			} else {
-
-				List<Object> allResults = new ArrayList();
-
-				for (DBObject result : results) {
-
-					Collection foo = (Collection<?>) result.get(RESULT_FIELD);
-					if (!CollectionUtils.isEmpty(foo)) {
-						allResults.addAll(foo);
-					}
-				}
-
-				// take general info from first batch
-				commandResult.put("serverUsed", results.iterator().next().get("serverUsed"));
-				commandResult.put("ok", results.iterator().next().get("ok"));
-
-				// and append the merged results
-				commandResult.put(RESULT_FIELD, allResults);
-			}
-
-			return commandResult;
-		}
-
-		private List<DBObject> aggregateBatched(String collectionName, int batchSize, DBObject command) {
+		/**
+		 * Pre process the aggregation command sent to the server by adding {@code cursor} options to match execution on
+		 * different server versions.
+		 */
+		private static DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation,
+				AggregationOperationContext context, int batchSize) {
+
+			AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context;
+			DBObject command = aggregation.toDbObject(collectionName, rootContext);
+
+			if (!aggregation.getOptions().isExplain()) {
+				command.put(CURSOR_FIELD, new BasicDBObject(BATCH_SIZE_FIELD, batchSize));
+			}
+
+			return command;
+		}
+
+		private List<DBObject> aggregateBatched(DBObject command, String collectionName, int batchSize) {

 			List<DBObject> results = new ArrayList<DBObject>();

-			CommandResult tmp = template.executeCommand(command, readPreference);
-			results.add(AggregationResultPostProcessor.INSTANCE.process(command, tmp));
+			CommandResult commandResult = template.executeCommand(command, readPreference);
+			results.add(postProcessResult(command, commandResult));

-			while (hasNext(tmp)) {
+			while (hasNext(commandResult)) {

-				DBObject getMore = new BasicDBObject("getMore", getNextBatchId(tmp)) //
+				DBObject getMore = new BasicDBObject("getMore", getNextBatchId(commandResult)) //
 						.append("collection", collectionName) //
-						.append(BATCH_SIZE_FIELD, batchSize); //
+						.append(BATCH_SIZE_FIELD, batchSize);

-				tmp = template.executeCommand(getMore, this.readPreference);
-				results.add(AggregationResultPostProcessor.INSTANCE.process(command, tmp));
+				commandResult = template.executeCommand(getMore, this.readPreference);
+				results.add(postProcessResult(command, commandResult));
 			}

 			return results;
 		}

-		private boolean hasNext(DBObject commandResult) {
+		private static DBObject postProcessResult(DBObject command, CommandResult commandResult) {
+
+			handleCommandError(commandResult, command);

 			if (!commandResult.containsField(CURSOR_FIELD)) {
-				return false;
+				return commandResult;
 			}

-			Object next = getNextBatchId(commandResult);
-			return (next == null || ((Number) next).longValue() == 0L) ? false : true;
-		}
+			DBObject resultObject = new BasicDBObject(SERVER_USED, commandResult.get(SERVER_USED));
+			resultObject.put(OK, commandResult.get(OK));

-		private Object getNextBatchId(DBObject commandResult) {
-			return ((DBObject) commandResult.get(CURSOR_FIELD)).get("id");
+			DBObject cursor = (DBObject) commandResult.get(CURSOR_FIELD);
+			if (cursor.containsField(FIRST_BATCH)) {
+				resultObject.put(RESULT_FIELD, cursor.get(FIRST_BATCH));
+			} else {
+				resultObject.put(RESULT_FIELD, cursor.get(NEXT_BATCH));
+			}
+
+			return resultObject;
 		}

-		/**
-		 * Helper to pre process the aggregation command sent to the server by adding {@code cursor} options to match
-		 * execution on different server versions.
-		 *
-		 * @author Christoph Strobl
-		 * @since 1.10
-		 */
-		private static enum AggregationCommandPreparer {
-
-			INSTANCE;
-
-			DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation,
-					AggregationOperationContext context, int batchSize) {
-
-				AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context;
-				DBObject command = aggregation.toDbObject(collectionName, rootContext);
-
-				if (!aggregation.getOptions().isExplain()) {
-					command.put(CURSOR_FIELD, new BasicDBObject(BATCH_SIZE_FIELD, batchSize));
-				}
-
-				return command;
-			}
-		}
-
-		/**
-		 * Helper to post process aggregation command result by copying over required attributes.
-		 *
-		 * @author Christoph Strobl
-		 * @since 1.10
-		 */
-		private static enum AggregationResultPostProcessor {
-
-			INSTANCE;
-
-			DBObject process(DBObject command, CommandResult commandResult) {
-
-				handleCommandError(commandResult, command);
-
-				if (!commandResult.containsField(CURSOR_FIELD)) {
-					return commandResult;
-				}
-
-				DBObject resultObject = new BasicDBObject("serverUsed", commandResult.get("serverUsed"));
-				resultObject.put("ok", commandResult.get("ok"));
-
-				DBObject cursor = (DBObject) commandResult.get(CURSOR_FIELD);
-				if (cursor.containsField("firstBatch")) {
-					resultObject.put(RESULT_FIELD, cursor.get("firstBatch"));
-				} else {
-					resultObject.put(RESULT_FIELD, cursor.get("nextBatch"));
-				}
-
-				return resultObject;
-			}
-		}
+		private static DBObject mergeAggregationResults(List<DBObject> batchResults) {
+
+			if (batchResults.size() == 1) {
+				return batchResults.iterator().next();
+			}
+
+			DBObject commandResult = new BasicDBObject(3);
+			List<Object> allResults = new ArrayList<Object>();
+
+			for (DBObject batchResult : batchResults) {
+
+				Collection documents = (Collection<?>) batchResult.get(RESULT_FIELD);
+				if (!CollectionUtils.isEmpty(documents)) {
+					allResults.addAll(documents);
+				}
+			}
+
+			// take general info from first batch
+			commandResult.put(SERVER_USED, batchResults.iterator().next().get(SERVER_USED));
+			commandResult.put(OK, batchResults.iterator().next().get(OK));
+
+			// and append the merged batchResults
+			commandResult.put(RESULT_FIELD, allResults);
+
+			return commandResult;
+		}
+
+		private static boolean hasNext(DBObject commandResult) {
+
+			if (!commandResult.containsField(CURSOR_FIELD)) {
+				return false;
+			}
+
+			Object next = getNextBatchId(commandResult);
+			return next != null && ((Number) next).longValue() != 0L;
+		}
+
+		private static Object getNextBatchId(DBObject commandResult) {
+			return ((DBObject) commandResult.get(CURSOR_FIELD)).get("id");
+		}
 	}
 }
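
For readers skimming the MongoTemplate diff above: the wire-level behaviour is untouched by this polishing, only names and structure change. Below is a minimal sketch of the two command shapes the loader builds, using the legacy driver types already visible in the diff; the collection name, pipeline stage, batch size and cursor id are made-up example values, not taken from the commit (the real initial command comes from aggregation.toDbObject).

import com.mongodb.BasicDBObject;
import com.mongodb.DBObject;

import java.util.Arrays;

/**
 * Illustrative only: the document shapes BatchAggregationLoader sends.
 * Values ("person", batch size 100, the $match stage) are arbitrary examples.
 */
public class AggregationCommandShapes {

	public static void main(String[] args) {

		// Initial aggregate command: prepareAggregationCommand appends a cursor
		// sub-document with a batchSize unless the aggregation is an explain run.
		DBObject aggregate = new BasicDBObject("aggregate", "person")
				.append("pipeline", Arrays.asList(new BasicDBObject("$match", new BasicDBObject("name", "luke"))))
				.append("cursor", new BasicDBObject("batchSize", 100));

		// Follow-up command issued by aggregateBatched while hasNext(...) sees a
		// non-zero cursor id in the previous reply.
		long cursorIdFromPreviousReply = 123L; // would come from the reply's cursor.id
		DBObject getMore = new BasicDBObject("getMore", cursorIdFromPreviousReply)
				.append("collection", "person")
				.append("batchSize", 100);

		System.out.println(aggregate);
		System.out.println(getMore);
	}
}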

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java

@@ -599,8 +599,8 @@ public class Aggregation {
 	/**
 	 * Get {@link AggregationOptions} to apply.
 	 *
-	 * @return never {@literal null}
-	 * @since 1.10
+	 * @return never {@literal null}.
+	 * @since 1.10.10
 	 */
 	public AggregationOptions getOptions() {
 		return options;
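
A small, hypothetical usage sketch for the getOptions() contract documented above. newAggregation, match and newAggregationOptions are the existing static factories on Aggregation; the criteria values and the allowDiskUse flag are arbitrary example choices.

import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
import org.springframework.data.mongodb.core.query.Criteria;

import static org.springframework.data.mongodb.core.aggregation.Aggregation.*;

public class AggregationOptionsExample {

	public static void main(String[] args) {

		// Build an aggregation and attach options; withOptions returns a new Aggregation instance.
		Aggregation aggregation = newAggregation(match(Criteria.where("name").is("luke")))
				.withOptions(newAggregationOptions().allowDiskUse(true).build());

		// getOptions() is documented above to never return null, so callers such as
		// BatchAggregationLoader can query it directly (e.g. isExplain()).
		AggregationOptions options = aggregation.getOptions();
		System.out.println(options.isExplain());      // false
		System.out.println(options.isAllowDiskUse()); // true
	}
}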

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java

@@ -38,7 +38,10 @@ import com.mongodb.DBObject;
 import com.mongodb.ReadPreference;

 /**
  * Unit tests for {@link BatchAggregationLoader}.
+ *
+ * @author Christoph Strobl
+ * @author Mark Paluch
  */
 @RunWith(MockitoJUnitRunner.class)
 public class BatchAggregationLoaderUnitTests {
@@ -52,11 +55,13 @@ public class BatchAggregationLoaderUnitTests {
 	BatchAggregationLoader loader;

-	DBObject cursorWithoutMore = new BasicDBObject("firstBatch", singletonList(new BasicDBObject("name", "luke")));
+	BasicDBObject luke = new BasicDBObject("name", "luke");
+	BasicDBObject han = new BasicDBObject("name", "han");
+
+	DBObject cursorWithoutMore = new BasicDBObject("firstBatch", singletonList(luke));
 	DBObject cursorWithMore = new BasicDBObject("id", 123).append("firstBatch",
-			singletonList(new BasicDBObject("name", "luke")));
+			singletonList(luke));
 	DBObject cursorWithNoMore = new BasicDBObject("id", 0).append("nextBatch",
-			singletonList(new BasicDBObject("name", "han")));
+			singletonList(han));

 	@Before
 	public void setUp() {
@@ -64,7 +69,20 @@ public class BatchAggregationLoaderUnitTests {
 	}

 	@Test // DATAMONGO-1824
-	public void shouldLoadJustOneBatchWhenAlreayDoneWithFirst() {
+	public void shouldLoadWithoutCursor() {

 		when(template.executeCommand(any(DBObject.class), any(ReadPreference.class))).thenReturn(aggregationResult);
+		when(aggregationResult.get("result")).thenReturn(singletonList(luke));
+
+		DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT);
+
+		assertThat((List<Object>) result.get("result"), IsCollectionContaining.<Object> hasItem(luke));
+
+		verify(template).executeCommand(any(DBObject.class), any(ReadPreference.class));
+		verifyNoMoreInteractions(template);
+	}
+
+	@Test // DATAMONGO-1824
+	public void shouldLoadJustOneBatchWhenAlreadyDoneWithFirst() {
+
+		when(template.executeCommand(any(DBObject.class), any(ReadPreference.class))).thenReturn(aggregationResult);
 		when(aggregationResult.containsField("cursor")).thenReturn(true);
@@ -72,7 +90,7 @@ public class BatchAggregationLoaderUnitTests {

 		DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT);

 		assertThat((List<Object>) result.get("result"),
-				IsCollectionContaining.<Object> hasItem(new BasicDBObject("name", "luke")));
+				IsCollectionContaining.<Object> hasItem(luke));

 		verify(template).executeCommand(any(DBObject.class), any(ReadPreference.class));
 		verifyNoMoreInteractions(template);
@@ -90,7 +108,7 @@ public class BatchAggregationLoaderUnitTests {

 		DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT);

 		assertThat((List<Object>) result.get("result"),
-				IsCollectionContaining.<Object> hasItems(new BasicDBObject("name", "luke"), new BasicDBObject("name", "han")));
+				IsCollectionContaining.<Object> hasItems(luke, han));

 		verify(template, times(2)).executeCommand(any(DBObject.class), any(ReadPreference.class));
 		verifyNoMoreInteractions(template);
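
The new shouldLoadWithoutCursor test targets replies that carry no cursor document. As an illustration only (the concrete documents are invented; the field names mirror the constants introduced in the MongoTemplate diff), the two reply shapes the loader distinguishes look roughly like this:

import com.mongodb.BasicDBObject;
import com.mongodb.DBObject;

import static java.util.Collections.singletonList;

/**
 * Illustrative only: the two aggregate reply shapes BatchAggregationLoader handles.
 * Field names follow the constants in the diff (cursor, firstBatch, nextBatch, result, ok);
 * the document contents are invented.
 */
public class AggregationReplyShapes {

	public static void main(String[] args) {

		// Cursor-style reply: postProcessResult copies firstBatch/nextBatch into "result",
		// and hasNext keeps fetching while cursor.id is non-zero.
		DBObject cursorStyle = new BasicDBObject("cursor",
				new BasicDBObject("id", 123L).append("firstBatch", singletonList(new BasicDBObject("name", "luke"))))
						.append("ok", 1.0);

		// Cursor-less reply: returned as-is, the documents already sit under "result".
		DBObject plainStyle = new BasicDBObject("result", singletonList(new BasicDBObject("name", "luke")))
				.append("ok", 1.0);

		System.out.println(cursorStyle);
		System.out.println(plainStyle);
	}
}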
