diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml
index 83cbb58fc..7a74583a5 100644
--- a/spring-data-mongodb/pom.xml
+++ b/spring-data-mongodb/pom.xml
@@ -305,7 +305,7 @@
io.mockk
- mockk
+ mockk-jvm
${mockk}
test
diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java
index bfbd5d40a..968f3bda1 100644
--- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java
+++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java
@@ -15,8 +15,6 @@
*/
package org.springframework.data.mongodb.core;
-import java.util.Arrays;
-import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
@@ -24,23 +22,16 @@ import java.util.stream.Collectors;
import org.bson.Document;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
-import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
-import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
import org.springframework.data.mongodb.core.aggregation.AggregationOptions.DomainTypeMapping;
-import org.springframework.data.mongodb.core.aggregation.CountOperation;
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
-import org.springframework.data.mongodb.core.query.CriteriaDefinition;
-import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.util.Lazy;
import org.springframework.lang.Nullable;
-import org.springframework.util.Assert;
-import org.springframework.util.ObjectUtils;
/**
* Utility methods to map {@link org.springframework.data.mongodb.core.aggregation.Aggregation} pipeline definitions and
@@ -75,12 +66,11 @@ class AggregationUtil {
if (!(aggregation instanceof TypedAggregation)) {
- if(inputType == null) {
+ if (inputType == null) {
return untypedMappingContext.get();
}
- if (domainTypeMapping == DomainTypeMapping.STRICT
- && !aggregation.getPipeline().containsUnionWith()) {
+ if (domainTypeMapping == DomainTypeMapping.STRICT && !aggregation.getPipeline().containsUnionWith()) {
return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
}
@@ -88,8 +78,7 @@ class AggregationUtil {
}
inputType = ((TypedAggregation>) aggregation).getInputType();
- if (domainTypeMapping == DomainTypeMapping.STRICT
- && !aggregation.getPipeline().containsUnionWith()) {
+ if (domainTypeMapping == DomainTypeMapping.STRICT && !aggregation.getPipeline().containsUnionWith()) {
return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
}
@@ -104,12 +93,7 @@ class AggregationUtil {
* @return
*/
List createPipeline(Aggregation aggregation, AggregationOperationContext context) {
-
- if (ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
- return aggregation.toPipeline(context);
- }
-
- return mapAggregationPipeline(aggregation.toPipeline(context));
+ return aggregation.toPipeline(context);
}
/**
@@ -120,16 +104,7 @@ class AggregationUtil {
* @return
*/
Document createCommand(String collection, Aggregation aggregation, AggregationOperationContext context) {
-
- Document command = aggregation.toDocument(collection, context);
-
- if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
- return command;
- }
-
- command.put("pipeline", mapAggregationPipeline(command.get("pipeline", List.class)));
-
- return command;
+ return aggregation.toDocument(collection, context);
}
private List mapAggregationPipeline(List pipeline) {
diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
index 2f58bec24..078eb65ba 100644
--- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
+++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
@@ -44,7 +44,6 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
-
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.annotation.Id;
import org.springframework.data.domain.Sort;
@@ -62,6 +61,7 @@ import org.springframework.data.mongodb.core.aggregation.VariableOperators.Let.E
import org.springframework.data.mongodb.core.geo.GeoJsonPoint;
import org.springframework.data.mongodb.core.index.GeoSpatialIndexType;
import org.springframework.data.mongodb.core.index.GeospatialIndex;
+import org.springframework.data.mongodb.core.mapping.MongoId;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.NearQuery;
import org.springframework.data.mongodb.core.query.Query;
@@ -1943,6 +1943,25 @@ public class AggregationTests {
assertThat(results.getMappedResults()).hasSize(1);
}
+ @Test // GH-4043
+ void considersMongoIdWithinTypedCollections() {
+
+ UserRef userRef = new UserRef();
+ userRef.id = "4ee921aca44fd11b3254e001";
+ userRef.name = "u-1";
+
+ Widget widget = new Widget();
+ widget.id = "w-1";
+ widget.users = Collections.singletonList(userRef);
+
+ mongoTemplate.save(widget);
+
+ Criteria criteria = Criteria.where("users").elemMatch(Criteria.where("id").is("4ee921aca44fd11b3254e001"));
+ AggregationResults aggregate = mongoTemplate.aggregate(newAggregation(match(criteria)), Widget.class,
+ Widget.class);
+ assertThat(aggregate.getMappedResults()).contains(widget);
+ }
+
private void createUsersWithReferencedPersons() {
mongoTemplate.dropCollection(User.class);
@@ -2266,4 +2285,16 @@ public class AggregationTests {
@Id String id;
MyEnum enumValue;
}
+
+ @lombok.Data
+ static class Widget {
+ @Id String id;
+ List users;
+ }
+
+ @lombok.Data
+ static class UserRef {
+ @MongoId String id;
+ String name;
+ }
}