From 41897c7d461a1ac61cf1118b358322796198342a Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 6 Jun 2018 09:53:54 +0200 Subject: [PATCH] DATAMONGO-1986 - Polishing. Refactor duplicated code into AggregationUtil. Original pull request: #564. --- .../data/mongodb/core/AggregationUtil.java | 121 ++++++++++++++++++ .../data/mongodb/core/MongoTemplate.java | 21 ++- .../core/BatchAggregationLoaderUnitTests.java | 12 +- .../data/mongodb/core/TestEntities.java | 2 +- .../core/aggregation/AggregationTests.java | 5 +- .../core/geo/AbstractGeoSpatialTests.java | 1 + 6 files changed, 151 insertions(+), 11 deletions(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java new file mode 100644 index 000000000..e0070cc5f --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java @@ -0,0 +1,121 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core; + +import lombok.AllArgsConstructor; + +import java.util.List; + +import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; +import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext; +import org.springframework.data.mongodb.core.aggregation.TypedAggregation; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; +import org.springframework.util.ObjectUtils; + +import com.mongodb.BasicDBList; +import com.mongodb.DBObject; + +/** + * Utility methods to map {@link org.springframework.data.mongodb.core.aggregation.Aggregation} pipeline definitions and + * create type-bound {@link AggregationOperationContext}. + * + * @author Christoph Strobl + * @author Mark Paluch + * @since 1.10.13 + */ +@AllArgsConstructor +class AggregationUtil { + + QueryMapper queryMapper; + MappingContext, MongoPersistentProperty> mappingContext; + + /** + * Prepare the {@link AggregationOperationContext} for a given aggregation by either returning the context itself it + * is not {@literal null}, create a {@link TypeBasedAggregationOperationContext} if the aggregation contains type + * information (is a {@link TypedAggregation}) or use the {@link Aggregation#DEFAULT_CONTEXT}. + * + * @param aggregation must not be {@literal null}. + * @param context can be {@literal null}. + * @return the root {@link AggregationOperationContext} to use. + */ + AggregationOperationContext prepareAggregationContext(Aggregation aggregation, AggregationOperationContext context) { + + if (context != null) { + return context; + } + + if (aggregation instanceof TypedAggregation) { + return new TypeBasedAggregationOperationContext(((TypedAggregation) aggregation).getInputType(), mappingContext, + queryMapper); + } + + return Aggregation.DEFAULT_CONTEXT; + } + + /** + * Extract and map the aggregation pipeline into a {@link List} of {@link Document}. + * + * @param aggregation + * @param context + * @return + */ + DBObject createPipeline(String collectionName, Aggregation aggregation, AggregationOperationContext context) { + + if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) { + return aggregation.toDbObject(collectionName, context); + } + + DBObject command = aggregation.toDbObject(collectionName, context); + command.put("pipeline", mapAggregationPipeline((List) command.get("pipeline"))); + + return command; + } + + /** + * Extract the command and map the aggregation pipeline. + * + * @param aggregation + * @param context + * @return + */ + DBObject createCommand(String collection, Aggregation aggregation, AggregationOperationContext context) { + + DBObject command = aggregation.toDbObject(collection, context); + + if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) { + return command; + } + + command.put("pipeline", mapAggregationPipeline((List) command.get("pipeline"))); + + return command; + } + + private BasicDBList mapAggregationPipeline(List pipeline) { + + BasicDBList mapped = new BasicDBList(); + + for (DBObject stage : pipeline) { + mapped.add(queryMapper.getMappedObject(stage, null)); + } + + return mapped; + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index a052a7782..7b6932804 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -1554,7 +1554,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { Assert.notNull(aggregation, "Aggregation pipeline must not be null!"); Assert.notNull(outputType, "Output type must not be null!"); - DBObject commandResult = new BatchAggregationLoader(this, readPreference, Integer.MAX_VALUE) + DBObject commandResult = new BatchAggregationLoader(this, queryMapper, mappingContext, readPreference, + Integer.MAX_VALUE) .aggregate(collectionName, aggregation, context); return new AggregationResults(returnPotentiallyMappedResults(outputType, commandResult, collectionName), @@ -2555,12 +2556,18 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { private static final String OK = "ok"; private final MongoTemplate template; + private final QueryMapper queryMapper; + private final MappingContext, MongoPersistentProperty> mappingContext; private final ReadPreference readPreference; private final int batchSize; - BatchAggregationLoader(MongoTemplate template, ReadPreference readPreference, int batchSize) { + BatchAggregationLoader(MongoTemplate template, QueryMapper queryMapper, + MappingContext, MongoPersistentProperty> mappingContext, + ReadPreference readPreference, int batchSize) { this.template = template; + this.queryMapper = queryMapper; + this.mappingContext = mappingContext; this.readPreference = readPreference; this.batchSize = batchSize; } @@ -2583,11 +2590,13 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { * Pre process the aggregation command sent to the server by adding {@code cursor} options to match execution on * different server versions. */ - private static DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation, - AggregationOperationContext context, int batchSize) { + private DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation, + AggregationOperationContext context, int batchSize) { - AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context; - DBObject command = aggregation.toDbObject(collectionName, rootContext); + AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext); + + AggregationOperationContext rootContext = aggregationUtil.prepareAggregationContext(aggregation, context); + DBObject command = aggregationUtil.createCommand(collectionName, aggregation, rootContext); if (!aggregation.getOptions().isExplain()) { command.put(CURSOR_FIELD, new BasicDBObject(BATCH_SIZE_FIELD, batchSize)); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java index 34838b45a..53e59f1f5 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java @@ -31,6 +31,10 @@ import org.mockito.runners.MockitoJUnitRunner; import org.springframework.data.mongodb.core.MongoTemplate.BatchAggregationLoader; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; +import org.springframework.data.mongodb.core.convert.DbRefResolver; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import com.mongodb.BasicDBObject; import com.mongodb.CommandResult; @@ -39,7 +43,7 @@ import com.mongodb.ReadPreference; /** * Unit tests for {@link BatchAggregationLoader}. - * + * * @author Christoph Strobl * @author Mark Paluch */ @@ -50,6 +54,7 @@ public class BatchAggregationLoaderUnitTests { project().and("firstName").as("name")); @Mock MongoTemplate template; + @Mock DbRefResolver dbRefResolver; @Mock CommandResult aggregationResult; @Mock CommandResult getMoreResult; @@ -65,7 +70,9 @@ public class BatchAggregationLoaderUnitTests { @Before public void setUp() { - loader = new BatchAggregationLoader(template, ReadPreference.primary(), 10); + MongoMappingContext context = new MongoMappingContext(); + loader = new BatchAggregationLoader(template, new QueryMapper(new MappingMongoConverter(dbRefResolver, context)), + context, ReadPreference.primary(), 10); } @Test // DATAMONGO-1824 @@ -89,6 +96,7 @@ public class BatchAggregationLoaderUnitTests { when(aggregationResult.get("cursor")).thenReturn(cursorWithoutMore); DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT); + assertThat((List) result.get("result"), IsCollectionContaining. hasItem(luke)); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/TestEntities.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/TestEntities.java index 77933cae3..77fe70b79 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/TestEntities.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/TestEntities.java @@ -84,7 +84,7 @@ public class TestEntities { public List newYork() { - List venues = new ArrayList<>(); + List venues = new ArrayList(); venues.add(pennStation()); venues.add(tenGenOffice()); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index e2c937be7..af428ac9f 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -1741,7 +1741,7 @@ public class AggregationTests { .within(new Box(new Point(-73.99756, 40.73083), new Point(-73.988135, 40.741404)))), project("id", "location", "name")); - AggregationResults groupResults = mongoTemplate.aggregate(aggregation, "newyork", Document.class); + AggregationResults groupResults = mongoTemplate.aggregate(aggregation, "newyork", DBObject.class); assertThat(groupResults.getMappedResults().size(), is(4)); } @@ -1756,7 +1756,8 @@ public class AggregationTests { .within(new Box(new Point(-73.99756, 40.73083), new Point(-73.988135, 40.741404)))), project("id", "location", "name")); - AggregationResults groupResults = mongoTemplate.aggregate(aggregation, "newyork", Document.class); + AggregationResults groupResults = mongoTemplate.aggregate(aggregation, "newyork", DBObject.class); assertThat(groupResults.getMappedResults().size(), is(4)); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/AbstractGeoSpatialTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/AbstractGeoSpatialTests.java index 20f44d2dd..0d771860e 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/AbstractGeoSpatialTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/AbstractGeoSpatialTests.java @@ -45,6 +45,7 @@ import org.springframework.data.mongodb.core.query.Query; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import com.mongodb.Mongo; import com.mongodb.MongoClient; import com.mongodb.WriteConcern;