Browse Source

Avoid multiple mapping iterations.

A 2nd pass is no longer needed as the context already does all the work.

Closes: #4043
Original pull request: #4240
3.4.x
Christoph Strobl 3 years ago committed by Mark Paluch
parent
commit
f7b913535e
No known key found for this signature in database
GPG Key ID: 4406B84C1661DCD1
  1. 2
      spring-data-mongodb/pom.xml
  2. 35
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java
  3. 33
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java

2
spring-data-mongodb/pom.xml

@ -305,7 +305,7 @@
<dependency> <dependency>
<groupId>io.mockk</groupId> <groupId>io.mockk</groupId>
<artifactId>mockk</artifactId> <artifactId>mockk-jvm</artifactId>
<version>${mockk}</version> <version>${mockk}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

35
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java

@ -15,8 +15,6 @@
*/ */
package org.springframework.data.mongodb.core; package org.springframework.data.mongodb.core;
import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -24,23 +22,16 @@ import java.util.stream.Collectors;
import org.bson.Document; import org.bson.Document;
import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
import org.springframework.data.mongodb.core.aggregation.AggregationOptions.DomainTypeMapping; import org.springframework.data.mongodb.core.aggregation.AggregationOptions.DomainTypeMapping;
import org.springframework.data.mongodb.core.aggregation.CountOperation;
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.convert.QueryMapper; import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.util.Lazy; import org.springframework.data.util.Lazy;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
/** /**
* Utility methods to map {@link org.springframework.data.mongodb.core.aggregation.Aggregation} pipeline definitions and * Utility methods to map {@link org.springframework.data.mongodb.core.aggregation.Aggregation} pipeline definitions and
@ -75,12 +66,11 @@ class AggregationUtil {
if (!(aggregation instanceof TypedAggregation)) { if (!(aggregation instanceof TypedAggregation)) {
if(inputType == null) { if (inputType == null) {
return untypedMappingContext.get(); return untypedMappingContext.get();
} }
if (domainTypeMapping == DomainTypeMapping.STRICT if (domainTypeMapping == DomainTypeMapping.STRICT && !aggregation.getPipeline().containsUnionWith()) {
&& !aggregation.getPipeline().containsUnionWith()) {
return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
} }
@ -88,8 +78,7 @@ class AggregationUtil {
} }
inputType = ((TypedAggregation<?>) aggregation).getInputType(); inputType = ((TypedAggregation<?>) aggregation).getInputType();
if (domainTypeMapping == DomainTypeMapping.STRICT if (domainTypeMapping == DomainTypeMapping.STRICT && !aggregation.getPipeline().containsUnionWith()) {
&& !aggregation.getPipeline().containsUnionWith()) {
return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
} }
@ -104,12 +93,7 @@ class AggregationUtil {
* @return * @return
*/ */
List<Document> createPipeline(Aggregation aggregation, AggregationOperationContext context) { List<Document> createPipeline(Aggregation aggregation, AggregationOperationContext context) {
return aggregation.toPipeline(context);
if (ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
return aggregation.toPipeline(context);
}
return mapAggregationPipeline(aggregation.toPipeline(context));
} }
/** /**
@ -120,16 +104,7 @@ class AggregationUtil {
* @return * @return
*/ */
Document createCommand(String collection, Aggregation aggregation, AggregationOperationContext context) { Document createCommand(String collection, Aggregation aggregation, AggregationOperationContext context) {
return aggregation.toDocument(collection, context);
Document command = aggregation.toDocument(collection, context);
if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
return command;
}
command.put("pipeline", mapAggregationPipeline(command.get("pipeline", List.class)));
return command;
} }
private List<Document> mapAggregationPipeline(List<Document> pipeline) { private List<Document> mapAggregationPipeline(List<Document> pipeline) {

33
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java

@ -44,7 +44,6 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.ClassPathResource;
import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Id;
import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort;
@ -62,6 +61,7 @@ import org.springframework.data.mongodb.core.aggregation.VariableOperators.Let.E
import org.springframework.data.mongodb.core.geo.GeoJsonPoint; import org.springframework.data.mongodb.core.geo.GeoJsonPoint;
import org.springframework.data.mongodb.core.index.GeoSpatialIndexType; import org.springframework.data.mongodb.core.index.GeoSpatialIndexType;
import org.springframework.data.mongodb.core.index.GeospatialIndex; import org.springframework.data.mongodb.core.index.GeospatialIndex;
import org.springframework.data.mongodb.core.mapping.MongoId;
import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.NearQuery;
import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Query;
@ -1943,6 +1943,25 @@ public class AggregationTests {
assertThat(results.getMappedResults()).hasSize(1); assertThat(results.getMappedResults()).hasSize(1);
} }
@Test // GH-4043
void considersMongoIdWithinTypedCollections() {
UserRef userRef = new UserRef();
userRef.id = "4ee921aca44fd11b3254e001";
userRef.name = "u-1";
Widget widget = new Widget();
widget.id = "w-1";
widget.users = Collections.singletonList(userRef);
mongoTemplate.save(widget);
Criteria criteria = Criteria.where("users").elemMatch(Criteria.where("id").is("4ee921aca44fd11b3254e001"));
AggregationResults<Widget> aggregate = mongoTemplate.aggregate(newAggregation(match(criteria)), Widget.class,
Widget.class);
assertThat(aggregate.getMappedResults()).contains(widget);
}
private void createUsersWithReferencedPersons() { private void createUsersWithReferencedPersons() {
mongoTemplate.dropCollection(User.class); mongoTemplate.dropCollection(User.class);
@ -2266,4 +2285,16 @@ public class AggregationTests {
@Id String id; @Id String id;
MyEnum enumValue; MyEnum enumValue;
} }
@lombok.Data
static class Widget {
@Id String id;
List<UserRef> users;
}
@lombok.Data
static class UserRef {
@MongoId String id;
String name;
}
} }

Loading…
Cancel
Save