Browse Source

DATAMONGO-1986 - Polishing.

Refactor duplicated code into AggregationUtil.

Original pull request: #564.
pull/585/head
Mark Paluch 8 years ago
parent
commit
41897c7d46
  1. 121
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java
  2. 21
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java
  3. 12
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java
  4. 2
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/TestEntities.java
  5. 5
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
  6. 1
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/AbstractGeoSpatialTests.java

121
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java

@ -0,0 +1,121 @@ @@ -0,0 +1,121 @@
/*
* Copyright 2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core;
import lombok.AllArgsConstructor;
import java.util.List;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.util.ObjectUtils;
import com.mongodb.BasicDBList;
import com.mongodb.DBObject;
/**
* Utility methods to map {@link org.springframework.data.mongodb.core.aggregation.Aggregation} pipeline definitions and
* create type-bound {@link AggregationOperationContext}.
*
* @author Christoph Strobl
* @author Mark Paluch
* @since 1.10.13
*/
@AllArgsConstructor
class AggregationUtil {
QueryMapper queryMapper;
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;
/**
* Prepare the {@link AggregationOperationContext} for a given aggregation by either returning the context itself it
* is not {@literal null}, create a {@link TypeBasedAggregationOperationContext} if the aggregation contains type
* information (is a {@link TypedAggregation}) or use the {@link Aggregation#DEFAULT_CONTEXT}.
*
* @param aggregation must not be {@literal null}.
* @param context can be {@literal null}.
* @return the root {@link AggregationOperationContext} to use.
*/
AggregationOperationContext prepareAggregationContext(Aggregation aggregation, AggregationOperationContext context) {
if (context != null) {
return context;
}
if (aggregation instanceof TypedAggregation) {
return new TypeBasedAggregationOperationContext(((TypedAggregation) aggregation).getInputType(), mappingContext,
queryMapper);
}
return Aggregation.DEFAULT_CONTEXT;
}
/**
* Extract and map the aggregation pipeline into a {@link List} of {@link Document}.
*
* @param aggregation
* @param context
* @return
*/
DBObject createPipeline(String collectionName, Aggregation aggregation, AggregationOperationContext context) {
if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
return aggregation.toDbObject(collectionName, context);
}
DBObject command = aggregation.toDbObject(collectionName, context);
command.put("pipeline", mapAggregationPipeline((List) command.get("pipeline")));
return command;
}
/**
* Extract the command and map the aggregation pipeline.
*
* @param aggregation
* @param context
* @return
*/
DBObject createCommand(String collection, Aggregation aggregation, AggregationOperationContext context) {
DBObject command = aggregation.toDbObject(collection, context);
if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
return command;
}
command.put("pipeline", mapAggregationPipeline((List) command.get("pipeline")));
return command;
}
private BasicDBList mapAggregationPipeline(List<DBObject> pipeline) {
BasicDBList mapped = new BasicDBList();
for (DBObject stage : pipeline) {
mapped.add(queryMapper.getMappedObject(stage, null));
}
return mapped;
}
}

21
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

@ -1554,7 +1554,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { @@ -1554,7 +1554,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
Assert.notNull(aggregation, "Aggregation pipeline must not be null!");
Assert.notNull(outputType, "Output type must not be null!");
DBObject commandResult = new BatchAggregationLoader(this, readPreference, Integer.MAX_VALUE)
DBObject commandResult = new BatchAggregationLoader(this, queryMapper, mappingContext, readPreference,
Integer.MAX_VALUE)
.aggregate(collectionName, aggregation, context);
return new AggregationResults<O>(returnPotentiallyMappedResults(outputType, commandResult, collectionName),
@ -2555,12 +2556,18 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { @@ -2555,12 +2556,18 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
private static final String OK = "ok";
private final MongoTemplate template;
private final QueryMapper queryMapper;
private final MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;
private final ReadPreference readPreference;
private final int batchSize;
BatchAggregationLoader(MongoTemplate template, ReadPreference readPreference, int batchSize) {
BatchAggregationLoader(MongoTemplate template, QueryMapper queryMapper,
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext,
ReadPreference readPreference, int batchSize) {
this.template = template;
this.queryMapper = queryMapper;
this.mappingContext = mappingContext;
this.readPreference = readPreference;
this.batchSize = batchSize;
}
@ -2583,11 +2590,13 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { @@ -2583,11 +2590,13 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
* Pre process the aggregation command sent to the server by adding {@code cursor} options to match execution on
* different server versions.
*/
private static DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation,
AggregationOperationContext context, int batchSize) {
private DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation,
AggregationOperationContext context, int batchSize) {
AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context;
DBObject command = aggregation.toDbObject(collectionName, rootContext);
AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext);
AggregationOperationContext rootContext = aggregationUtil.prepareAggregationContext(aggregation, context);
DBObject command = aggregationUtil.createCommand(collectionName, aggregation, rootContext);
if (!aggregation.getOptions().isExplain()) {
command.put(CURSOR_FIELD, new BasicDBObject(BATCH_SIZE_FIELD, batchSize));

12
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java

@ -31,6 +31,10 @@ import org.mockito.runners.MockitoJUnitRunner; @@ -31,6 +31,10 @@ import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.data.mongodb.core.MongoTemplate.BatchAggregationLoader;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.convert.DbRefResolver;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import com.mongodb.BasicDBObject;
import com.mongodb.CommandResult;
@ -39,7 +43,7 @@ import com.mongodb.ReadPreference; @@ -39,7 +43,7 @@ import com.mongodb.ReadPreference;
/**
* Unit tests for {@link BatchAggregationLoader}.
*
*
* @author Christoph Strobl
* @author Mark Paluch
*/
@ -50,6 +54,7 @@ public class BatchAggregationLoaderUnitTests { @@ -50,6 +54,7 @@ public class BatchAggregationLoaderUnitTests {
project().and("firstName").as("name"));
@Mock MongoTemplate template;
@Mock DbRefResolver dbRefResolver;
@Mock CommandResult aggregationResult;
@Mock CommandResult getMoreResult;
@ -65,7 +70,9 @@ public class BatchAggregationLoaderUnitTests { @@ -65,7 +70,9 @@ public class BatchAggregationLoaderUnitTests {
@Before
public void setUp() {
loader = new BatchAggregationLoader(template, ReadPreference.primary(), 10);
MongoMappingContext context = new MongoMappingContext();
loader = new BatchAggregationLoader(template, new QueryMapper(new MappingMongoConverter(dbRefResolver, context)),
context, ReadPreference.primary(), 10);
}
@Test // DATAMONGO-1824
@ -89,6 +96,7 @@ public class BatchAggregationLoaderUnitTests { @@ -89,6 +96,7 @@ public class BatchAggregationLoaderUnitTests {
when(aggregationResult.get("cursor")).thenReturn(cursorWithoutMore);
DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT);
assertThat((List<Object>) result.get("result"),
IsCollectionContaining.<Object> hasItem(luke));

2
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/TestEntities.java

@ -84,7 +84,7 @@ public class TestEntities { @@ -84,7 +84,7 @@ public class TestEntities {
public List<Venue> newYork() {
List<Venue> venues = new ArrayList<>();
List<Venue> venues = new ArrayList<Venue>();
venues.add(pennStation());
venues.add(tenGenOffice());

5
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java

@ -1741,7 +1741,7 @@ public class AggregationTests { @@ -1741,7 +1741,7 @@ public class AggregationTests {
.within(new Box(new Point(-73.99756, 40.73083), new Point(-73.988135, 40.741404)))),
project("id", "location", "name"));
AggregationResults<Document> groupResults = mongoTemplate.aggregate(aggregation, "newyork", Document.class);
AggregationResults<DBObject> groupResults = mongoTemplate.aggregate(aggregation, "newyork", DBObject.class);
assertThat(groupResults.getMappedResults().size(), is(4));
}
@ -1756,7 +1756,8 @@ public class AggregationTests { @@ -1756,7 +1756,8 @@ public class AggregationTests {
.within(new Box(new Point(-73.99756, 40.73083), new Point(-73.988135, 40.741404)))),
project("id", "location", "name"));
AggregationResults<Document> groupResults = mongoTemplate.aggregate(aggregation, "newyork", Document.class);
AggregationResults<DBObject
> groupResults = mongoTemplate.aggregate(aggregation, "newyork", DBObject.class);
assertThat(groupResults.getMappedResults().size(), is(4));
}

1
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/AbstractGeoSpatialTests.java

@ -45,6 +45,7 @@ import org.springframework.data.mongodb.core.query.Query; @@ -45,6 +45,7 @@ import org.springframework.data.mongodb.core.query.Query;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import com.mongodb.Mongo;
import com.mongodb.MongoClient;
import com.mongodb.WriteConcern;

Loading…
Cancel
Save