diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappedDocument.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappedDocument.java index dee860657..2efe43524 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappedDocument.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappedDocument.java @@ -137,5 +137,14 @@ public class MappedDocument { public Boolean isIsolated() { return delegate.isIsolated(); } + + /* + * (non-Javadoc) + * @see org.springframework.data.mongodb.core.query.UpdateDefinition#getArrayFilters() + */ + @Override + public List getArrayFilters() { + return delegate.getArrayFilters(); + } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index a8f0babcf..e7c16acfa 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -104,6 +104,7 @@ import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; import org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.core.query.UpdateDefinition.ArrayFilter; import org.springframework.data.mongodb.core.validation.Validator; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.util.CloseableIterator; @@ -1587,6 +1588,11 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, UpdateOptions opts = new UpdateOptions(); opts.upsert(upsert); + if (update.hasArrayFilters()) { + opts.arrayFilters( + update.getArrayFilters().stream().map(ArrayFilter::asDocument).collect(Collectors.toList())); + } + Document queryObj = new Document(); if (query != null) { @@ -2551,7 +2557,9 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, collectionName); } - return executeFindOneInternal(new FindAndModifyCallback(mappedQuery, fields, sort, mappedUpdate, options), + return executeFindOneInternal( + new FindAndModifyCallback(mappedQuery, fields, sort, mappedUpdate, + update.getArrayFilters().stream().map(ArrayFilter::asDocument).collect(Collectors.toList()), options), new ReadDocumentCallback<>(readerToUse, entityClass, collectionName), collectionName); } @@ -2908,14 +2916,16 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, private final Document fields; private final Document sort; private final Document update; + private final List arrayFilters; private final FindAndModifyOptions options; public FindAndModifyCallback(Document query, Document fields, Document sort, Document update, - FindAndModifyOptions options) { + List arrayFilters, FindAndModifyOptions options) { this.query = query; this.fields = fields; this.sort = sort; this.update = update; + this.arrayFilters = arrayFilters; this.options = options; } @@ -2933,6 +2943,10 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, options.getCollation().map(Collation::toMongoCollation).ifPresent(opts::collation); + if (!arrayFilters.isEmpty()) { + opts.arrayFilters(arrayFilters); + } + return collection.findOneAndUpdate(query, update, opts); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java index f92cd7632..f2c4837f1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java @@ -92,6 +92,7 @@ import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; import org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.core.query.UpdateDefinition.ArrayFilter; import org.springframework.data.mongodb.core.validation.Validator; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.util.Optionals; @@ -1640,6 +1641,11 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati UpdateOptions updateOptions = new UpdateOptions().upsert(upsert); query.getCollation().map(Collation::toMongoCollation).ifPresent(updateOptions::collation); + if (update.hasArrayFilters()) { + updateOptions.arrayFilters(update.getArrayFilters().stream().map(ArrayFilter::asDocument) + .map(it -> queryMapper.getMappedObject(it, entity)).collect(Collectors.toList())); + } + if (!UpdateMapper.isUpdateObject(updateObj)) { ReplaceOptions replaceOptions = new ReplaceOptions(); @@ -2367,7 +2373,7 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati collectionName)); } - return executeFindOneInternal(new FindAndModifyCallback(mappedQuery, fields, sort, mappedUpdate, options), + return executeFindOneInternal(new FindAndModifyCallback(mappedQuery, fields, sort, mappedUpdate, update.getArrayFilters().stream().map(ArrayFilter::asDocument).collect(Collectors.toList()), options), new ReadDocumentCallback<>(this.mongoConverter, entityClass, collectionName), collectionName); }); } @@ -2751,6 +2757,7 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati private final Document fields; private final Document sort; private final Document update; + private final List arrayFilters; private final FindAndModifyOptions options; @Override @@ -2766,12 +2773,12 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati return collection.findOneAndDelete(query, findOneAndDeleteOptions); } - FindOneAndUpdateOptions findOneAndUpdateOptions = convertToFindOneAndUpdateOptions(options, fields, sort); + FindOneAndUpdateOptions findOneAndUpdateOptions = convertToFindOneAndUpdateOptions(options, fields, sort, arrayFilters); return collection.findOneAndUpdate(query, update, findOneAndUpdateOptions); } - private FindOneAndUpdateOptions convertToFindOneAndUpdateOptions(FindAndModifyOptions options, Document fields, - Document sort) { + private static FindOneAndUpdateOptions convertToFindOneAndUpdateOptions(FindAndModifyOptions options, Document fields, + Document sort, List arrayFilters) { FindOneAndUpdateOptions result = new FindOneAndUpdateOptions(); @@ -2784,6 +2791,7 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati } result = options.getCollation().map(Collation::toMongoCollation).map(result::collation).orElse(result); + result.arrayFilters(arrayFilters); return result; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java index a9c52ea4c..5154d4ea0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java @@ -15,6 +15,7 @@ */ package org.springframework.data.mongodb.core.query; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -58,6 +59,7 @@ public class Update implements UpdateDefinition { private Set keysToUpdate = new HashSet<>(); private Map modifierOps = new LinkedHashMap<>(); private Map pushCommandBuilders = new LinkedHashMap<>(1); + private List arrayFilters = new ArrayList<>(); /** * Static factory method to create an Update using the provided key @@ -399,6 +401,33 @@ public class Update implements UpdateDefinition { return this; } + /** + * Filter elements in an array that match the given criteria for update. + * + * @param criteria must not be {@literal null}. + * @return this. + * @since 2.2 + */ + public Update filterArray(CriteriaDefinition criteria) { + + this.arrayFilters.add(() -> criteria.getCriteriaObject()); + return this; + } + + /** + * Filter elements in an array that match the given criteria for update. + * + * @param identifier the positional operator identifier filter criteria name. + * @param expression the positional operator filter expression. + * @return this. + * @since 2.2 + */ + public Update filterArray(String identifier, Object expression) { + + this.arrayFilters.add(() -> new Document(identifier, expression)); + return this; + } + /* * (non-Javadoc) * @see org.springframework.data.mongodb.core.query.UpdateDefinition#isIsolated() @@ -415,6 +444,14 @@ public class Update implements UpdateDefinition { return new Document(modifierOps); } + /* + * (non-Javadoc) + * @see org.springframework.data.mongodb.core.query.UpdateDefinition#getArrayFilters() + */ + public List getArrayFilters() { + return Collections.unmodifiableList(this.arrayFilters); + } + /** * This method is not called anymore rather override {@link #addMultiFieldOperation(String, String, Object)}. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/UpdateDefinition.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/UpdateDefinition.java index 7918d0d74..c2b223676 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/UpdateDefinition.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/UpdateDefinition.java @@ -15,6 +15,8 @@ */ package org.springframework.data.mongodb.core.query; +import java.util.List; + import org.bson.Document; /** @@ -53,4 +55,36 @@ public interface UpdateDefinition { * @param key must not be {@literal null}. */ void inc(String key); + + /** + * Get the specification which elements to modify in an array field. + * + * @return never {@literal null}. + * @since 2.2 + */ + List getArrayFilters(); + + /** + * @return {@literal true} if {@link UpdateDefinition} contains {@link #getArrayFilters() array filters}. + * @since 2.2 + */ + default boolean hasArrayFilters() { + return !getArrayFilters().isEmpty(); + } + + /** + * A filter to specify which elements to modify in an array field. + * + * @since 2.2 + */ + interface ArrayFilter { + + /** + * Get the {@link Document} representation of the filter to apply. The returned Document is subject to mapping + * domain type filed names. + * + * @return never {@literal null}. + */ + Document asDocument(); + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java index 74b5bddec..cc89394ec 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java @@ -32,6 +32,7 @@ import java.util.Map; import java.util.Optional; import java.util.regex.Pattern; +import org.assertj.core.api.Assertions; import org.bson.Document; import org.bson.conversions.Bson; import org.bson.types.ObjectId; @@ -1034,6 +1035,34 @@ public class MongoTemplateUnitTests extends MongoOperationsUnitTests { is(equalTo(new Document("version", 11).append("_class", VersionedEntity.class.getName())))); } + @Test // DATAMONGO-2215 + public void updateShouldApplyArrayFilters() { + + template.updateFirst(new BasicQuery("{}"), + new Update().set("grades.$[element]", 100).filterArray(Criteria.where("element").gte(100)), + EntityWithListOfSimple.class); + + ArgumentCaptor options = ArgumentCaptor.forClass(UpdateOptions.class); + verify(collection).updateOne(any(), any(), options.capture()); + + Assertions.assertThat((List) options.getValue().getArrayFilters()) + .contains(new org.bson.Document("element", new Document("$gte", 100))); + } + + @Test // DATAMONGO-2215 + public void findAndModifyShouldApplyArrayFilters() { + + template.findAndModify(new BasicQuery("{}"), + new Update().set("grades.$[element]", 100).filterArray(Criteria.where("element").gte(100)), + EntityWithListOfSimple.class); + + ArgumentCaptor options = ArgumentCaptor.forClass(FindOneAndUpdateOptions.class); + verify(collection).findOneAndUpdate(any(), any(), options.capture()); + + Assertions.assertThat((List) options.getValue().getArrayFilters()) + .contains(new org.bson.Document("element", new Document("$gte", 100))); + } + class AutogenerateableId { @Id BigInteger id; @@ -1102,6 +1131,10 @@ public class MongoTemplateUnitTests extends MongoOperationsUnitTests { AutogenerateableId foo; } + static class EntityWithListOfSimple { + List grades; + } + /** * Mocks out the {@link MongoTemplate#getDb()} method to return the {@link DB} mock instead of executing the actual * behaviour. diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateUnitTests.java index 8ae62aca0..e2b2d8ccc 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateUnitTests.java @@ -25,8 +25,10 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import org.assertj.core.api.Assertions; import org.bson.Document; import org.bson.conversions.Bson; import org.bson.types.ObjectId; @@ -47,6 +49,7 @@ import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.Collation; +import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; @@ -81,6 +84,8 @@ public class ReactiveMongoTemplateUnitTests { @Mock FindPublisher findPublisher; @Mock AggregatePublisher aggregatePublisher; @Mock Publisher runCommandPublisher; + @Mock Publisher updatePublisher; + @Mock Publisher findAndUpdatePublisher; MongoExceptionTranslator exceptionTranslator = new MongoExceptionTranslator(); MappingMongoConverter converter; @@ -98,6 +103,8 @@ public class ReactiveMongoTemplateUnitTests { when(collection.find(any(Document.class), any(Class.class))).thenReturn(findPublisher); when(collection.aggregate(anyList())).thenReturn(aggregatePublisher); when(collection.aggregate(anyList(), any(Class.class))).thenReturn(aggregatePublisher); + when(collection.updateOne(any(), any(), any(UpdateOptions.class))).thenReturn(updatePublisher); + when(collection.findOneAndUpdate(any(), any(), any(FindOneAndUpdateOptions.class))).thenReturn(findAndUpdatePublisher); when(findPublisher.projection(any())).thenReturn(findPublisher); when(findPublisher.limit(anyInt())).thenReturn(findPublisher); when(findPublisher.collation(any())).thenReturn(findPublisher); @@ -343,6 +350,34 @@ public class ReactiveMongoTemplateUnitTests { verify(findPublisher, never()).projection(any()); } + @Test // DATAMONGO-2215 + public void updateShouldApplyArrayFilters() { + + template.updateFirst(new BasicQuery("{}"), + new Update().set("grades.$[element]", 100).filterArray(Criteria.where("element").gte(100)), + EntityWithListOfSimple.class).subscribe(); + + ArgumentCaptor options = ArgumentCaptor.forClass(UpdateOptions.class); + verify(collection).updateOne(any(), any(), options.capture()); + + Assertions.assertThat((List) options.getValue().getArrayFilters()) + .contains(new org.bson.Document("element", new Document("$gte", 100))); + } + + @Test // DATAMONGO-2215 + public void findAndModifyShouldApplyArrayFilters() { + + template.findAndModify(new BasicQuery("{}"), + new Update().set("grades.$[element]", 100).filterArray(Criteria.where("element").gte(100)), + EntityWithListOfSimple.class).subscribe(); + + ArgumentCaptor options = ArgumentCaptor.forClass(FindOneAndUpdateOptions.class); + verify(collection).findOneAndUpdate(any(), any(), options.capture()); + + Assertions.assertThat((List) options.getValue().getArrayFilters()) + .contains(new org.bson.Document("element", new Document("$gte", 100))); + } + @Data @org.springframework.data.mongodb.core.mapping.Document(collection = "star-wars") static class Person { @@ -371,4 +406,8 @@ public class ReactiveMongoTemplateUnitTests { @Field("firstname") String name; } + + static class EntityWithListOfSimple { + List grades; + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/UpdateMapperUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/UpdateMapperUnitTests.java index 6f6f81d52..ed522c442 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/UpdateMapperUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/UpdateMapperUnitTests.java @@ -1051,6 +1051,42 @@ public class UpdateMapperUnitTests { assertThat(mappedUpdate).isEqualTo(new Document("$inc", new Document("aliased.$[]", 10))); } + @Test // DATAMONGO-2215 + public void mappingShouldAllowPositionParameterWithIdentifier() { + + Update update = new Update().set("grades.$[element]", 10) // + .filterArray(Criteria.where("element").gte(100)); + + Document mappedUpdate = mapper.getMappedObject(update.getUpdateObject(), + context.getPersistentEntity(EntityWithListOfSimple.class)); + + assertThat(mappedUpdate).isEqualTo(new Document("$set", new Document("grades.$[element]", 10))); + } + + @Test // DATAMONGO-2215 + public void mappingShouldAllowPositionParameterWithIdentifierWhenFieldHasExplicitFieldName() { + + Update update = new Update().set("list.$[element]", 10) // + .filterArray(Criteria.where("element").gte(100)); + + Document mappedUpdate = mapper.getMappedObject(update.getUpdateObject(), + context.getPersistentEntity(ParentClass.class)); + + assertThat(mappedUpdate).isEqualTo(new Document("$set", new Document("aliased.$[element]", 10))); + } + + @Test // DATAMONGO-2215 + public void mappingShouldAllowNestedPositionParameterWithIdentifierWhenFieldHasExplicitFieldName() { + + Update update = new Update().set("list.$[element].value", 10) // + .filterArray(Criteria.where("element").gte(100)); + + Document mappedUpdate = mapper.getMappedObject(update.getUpdateObject(), + context.getPersistentEntity(ParentClass.class)); + + assertThat(mappedUpdate).isEqualTo(new Document("$set", new Document("aliased.$[element].value", 10))); + } + static class DomainTypeWrappingConcreteyTypeHavingListOfInterfaceTypeAttributes { ListModelWrapper concreteTypeWithListAttributeOfInterfaceType; }