From c3383432f7c5e828654fdf7a7dd42f255922c8e0 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 23 Jun 2017 07:57:43 +0200 Subject: [PATCH] DATAMONGO-1678 - Run bulk update / remove documents through type mappers. We now make sure to run any query / update object through the Query- / UpdateMapper. This ensures @Field annotations and potential custom conversions get processed correctly for update / remove operations. Original pull request: #472. --- .../mongodb/core/DefaultBulkOperations.java | 107 ++++++++++++++---- .../data/mongodb/core/MongoTemplate.java | 4 +- ...DefaultBulkOperationsIntegrationTests.java | 38 +++++-- .../core/DefaultBulkOperationsUnitTests.java | 49 +++++++- 4 files changed, 166 insertions(+), 32 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultBulkOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultBulkOperations.java index fb4411b47..6c63ad49b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultBulkOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultBulkOperations.java @@ -15,14 +15,21 @@ */ package org.springframework.data.mongodb.core; +import lombok.Data; + import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; -import com.mongodb.client.model.DeleteOptions; import org.bson.Document; +import org.bson.conversions.Bson; import org.springframework.dao.DataAccessException; import org.springframework.dao.support.PersistenceExceptionTranslator; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.convert.UpdateMapper; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; import org.springframework.data.util.Pair; @@ -33,6 +40,8 @@ import com.mongodb.WriteConcern; import com.mongodb.client.MongoCollection; import com.mongodb.client.model.BulkWriteOptions; import com.mongodb.client.model.DeleteManyModel; +import com.mongodb.client.model.DeleteOneModel; +import com.mongodb.client.model.DeleteOptions; import com.mongodb.client.model.InsertOneModel; import com.mongodb.client.model.UpdateManyModel; import com.mongodb.client.model.UpdateOneModel; @@ -41,7 +50,7 @@ import com.mongodb.client.model.WriteModel; /** * Default implementation for {@link BulkOperations}. - * + * * @author Tobias Trelle * @author Oliver Gierke * @author Christoph Strobl @@ -50,11 +59,11 @@ import com.mongodb.client.model.WriteModel; class DefaultBulkOperations implements BulkOperations { private final MongoOperations mongoOperations; - private final BulkMode bulkMode; private final String collectionName; + private final BulkOperationContext bulkOperationContext; - private PersistenceExceptionTranslator exceptionTranslator; private WriteConcernResolver writeConcernResolver; + private PersistenceExceptionTranslator exceptionTranslator; private WriteConcern defaultWriteConcern; private BulkWriteOptions bulkOptions; @@ -62,28 +71,25 @@ class DefaultBulkOperations implements BulkOperations { List> models = new ArrayList<>(); /** - * Creates a new {@link DefaultBulkOperations} for the given {@link MongoOperations}, {@link BulkMode}, collection - * name and {@link WriteConcern}. + * Creates a new {@link DefaultBulkOperations} for the given {@link MongoOperations}, collection name and + * {@link BulkOperationContext}. * - * @param mongoOperations The underlying {@link MongoOperations}, must not be {@literal null}. - * @param bulkMode must not be {@literal null}. - * @param collectionName Name of the collection to work on, must not be {@literal null} or empty. - * @param entityType the entity type, can be {@literal null}. + * @param mongoOperations must not be {@literal null}. + * @param collectionName must not be {@literal null}. + * @param bulkOperationContext must not be {@literal null}. + * @since 2.0 */ - DefaultBulkOperations(MongoOperations mongoOperations, BulkMode bulkMode, String collectionName, - Class entityType) { + DefaultBulkOperations(MongoOperations mongoOperations, String collectionName, + BulkOperationContext bulkOperationContext) { Assert.notNull(mongoOperations, "MongoOperations must not be null!"); - Assert.notNull(bulkMode, "BulkMode must not be null!"); - Assert.hasText(collectionName, "Collection name must not be null or empty!"); + Assert.hasText(collectionName, "CollectionName must not be null nor empty!"); + Assert.notNull(bulkOperationContext, "BulkOperationContext must not be null!"); this.mongoOperations = mongoOperations; - this.bulkMode = bulkMode; this.collectionName = collectionName; - + this.bulkOperationContext = bulkOperationContext; this.exceptionTranslator = new MongoExceptionTranslator(); - this.writeConcernResolver = DefaultWriteConcernResolver.INSTANCE; - this.bulkOptions = initBulkOperation(); } @@ -282,7 +288,7 @@ class DefaultBulkOperations implements BulkOperations { collection = collection.withWriteConcern(defaultWriteConcern); } - return collection.bulkWrite(models, bulkOptions); + return collection.bulkWrite(models.stream().map(this::mapWriteModel).collect(Collectors.toList()), bulkOptions); } catch (BulkWriteException o_O) { @@ -323,7 +329,8 @@ class DefaultBulkOperations implements BulkOperations { private final BulkWriteOptions initBulkOperation() { BulkWriteOptions options = new BulkWriteOptions(); - switch (bulkMode) { + + switch (bulkOperationContext.getBulkMode()) { case ORDERED: return options.ordered(true); case UNORDERED: @@ -331,4 +338,64 @@ class DefaultBulkOperations implements BulkOperations { } throw new IllegalStateException("BulkMode was null!"); } + + private WriteModel mapWriteModel(WriteModel writeModel) { + + if (writeModel instanceof UpdateOneModel) { + + UpdateOneModel model = (UpdateOneModel) writeModel; + + return new UpdateOneModel(getMappedQuery(model.getFilter()), getMappedUpdate(model.getUpdate()), + model.getOptions()); + } + + if (writeModel instanceof UpdateManyModel) { + + UpdateManyModel model = (UpdateManyModel) writeModel; + + return new UpdateManyModel(getMappedQuery(model.getFilter()), getMappedUpdate(model.getUpdate()), + model.getOptions()); + } + + if (writeModel instanceof DeleteOneModel) { + + DeleteOneModel model = (DeleteOneModel) writeModel; + + return new DeleteOneModel(getMappedQuery(model.getFilter()), model.getOptions()); + } + + if (writeModel instanceof DeleteManyModel) { + + DeleteManyModel model = (DeleteManyModel) writeModel; + + return new DeleteManyModel(getMappedQuery(model.getFilter()), model.getOptions()); + } + + return writeModel; + } + + private Bson getMappedUpdate(Bson update) { + return bulkOperationContext.getUpdateMapper().getMappedObject(update, bulkOperationContext.getEntity()); + } + + private Bson getMappedQuery(Bson query) { + return bulkOperationContext.getQueryMapper().getMappedObject(query, bulkOperationContext.getEntity()); + } + + /** + * {@link BulkOperationContext} holds information about + * {@link org.springframework.data.mongodb.core.BulkOperations.BulkMode} the entity in use as well as references to + * {@link QueryMapper} and {@link UpdateMapper}. + * + * @author Christoph Strobl + * @since 2.0 + */ + @Data + static class BulkOperationContext { + + final BulkMode bulkMode; + final Optional> entity; + final QueryMapper queryMapper; + final UpdateMapper updateMapper; + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 82bf2ca50..88612c16c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -67,6 +67,7 @@ import org.springframework.data.mapping.model.ConvertingPropertyAccessor; import org.springframework.data.mapping.model.MappingException; import org.springframework.data.mongodb.MongoDbFactory; import org.springframework.data.mongodb.core.BulkOperations.BulkMode; +import org.springframework.data.mongodb.core.DefaultBulkOperations.BulkOperationContext; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.AggregationOptions; @@ -557,7 +558,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, Assert.notNull(mode, "BulkMode must not be null!"); Assert.hasText(collectionName, "Collection name must not be null or empty!"); - DefaultBulkOperations operations = new DefaultBulkOperations(this, mode, collectionName, entityType); + DefaultBulkOperations operations = new DefaultBulkOperations(this, collectionName, + new BulkOperationContext(mode, getPersistentEntity(entityType), queryMapper, updateMapper)); operations.setExceptionTranslator(exceptionTranslator); operations.setWriteConcernResolver(writeConcernResolver); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsIntegrationTests.java index 1bd055b1d..b0c749b45 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsIntegrationTests.java @@ -21,6 +21,7 @@ import static org.junit.Assert.*; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Optional; import org.bson.Document; import org.junit.Before; @@ -28,6 +29,10 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.mongodb.core.BulkOperations.BulkMode; +import org.springframework.data.mongodb.core.DefaultBulkOperations.BulkOperationContext; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.convert.UpdateMapper; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; @@ -65,17 +70,20 @@ public class DefaultBulkOperationsIntegrationTests { @Test(expected = IllegalArgumentException.class) // DATAMONGO-934 public void rejectsNullMongoOperations() { - new DefaultBulkOperations(null, null, COLLECTION_NAME, null); + new DefaultBulkOperations(null, COLLECTION_NAME, + new BulkOperationContext(BulkMode.ORDERED, Optional.empty(), null, null)); + } @Test(expected = IllegalArgumentException.class) // DATAMONGO-934 public void rejectsNullCollectionName() { - new DefaultBulkOperations(operations, null, null, null); + new DefaultBulkOperations(operations, null, + new BulkOperationContext(BulkMode.ORDERED, Optional.empty(), null, null)); } @Test(expected = IllegalArgumentException.class) // DATAMONGO-934 public void rejectsEmptyCollectionName() { - new DefaultBulkOperations(operations, null, "", null); + new DefaultBulkOperations(operations, "", new BulkOperationContext(BulkMode.ORDERED, Optional.empty(), null, null)); } @Test // DATAMONGO-934 @@ -191,7 +199,7 @@ public class DefaultBulkOperationsIntegrationTests { @Test // DATAMONGO-934 public void mixedBulkOrdered() { - com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED).insert(newDoc("1", "v1")).// + com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED, BaseDoc.class).insert(newDoc("1", "v1")).// updateOne(where("_id", "1"), set("value", "v2")).// remove(where("value", "v2")).// execute(); @@ -213,8 +221,8 @@ public class DefaultBulkOperationsIntegrationTests { List> updates = Arrays.asList(Pair.of(where("value", "v2"), set("value", "v3"))); List removes = Arrays.asList(where("_id", "1")); - com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED).insert(inserts).updateMulti(updates) - .remove(removes).execute(); + com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED, BaseDoc.class).insert(inserts) + .updateMulti(updates).remove(removes).execute(); assertThat(result, notNullValue()); assertThat(result.getInsertedCount(), is(3)); @@ -230,7 +238,7 @@ public class DefaultBulkOperationsIntegrationTests { specialDoc.value = "normal-value"; specialDoc.specialValue = "special-value"; - createBulkOps(BulkMode.ORDERED).insert(Arrays.asList(specialDoc)).execute(); + createBulkOps(BulkMode.ORDERED, SpecialDoc.class).insert(Arrays.asList(specialDoc)).execute(); BaseDoc doc = operations.findOne(where("_id", specialDoc.id), BaseDoc.class, COLLECTION_NAME); @@ -264,11 +272,21 @@ public class DefaultBulkOperationsIntegrationTests { } private BulkOperations createBulkOps(BulkMode mode) { + return createBulkOps(mode, null); + } + + private BulkOperations createBulkOps(BulkMode mode, Class entityType) { + + Optional> entity = entityType != null + ? operations.getConverter().getMappingContext().getPersistentEntity(entityType) : Optional.empty(); + + BulkOperationContext bulkOperationContext = new BulkOperationContext(mode, entity, + new QueryMapper(operations.getConverter()), new UpdateMapper(operations.getConverter())); - DefaultBulkOperations operations = new DefaultBulkOperations(this.operations, mode, COLLECTION_NAME, null); - operations.setDefaultWriteConcern(WriteConcern.ACKNOWLEDGED); + DefaultBulkOperations bulkOps = new DefaultBulkOperations(operations, COLLECTION_NAME, bulkOperationContext); + bulkOps.setDefaultWriteConcern(WriteConcern.ACKNOWLEDGED); - return operations; + return bulkOps; } private void insertSomeDocuments() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsUnitTests.java index f0d3ea1d1..d81b5c139 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsUnitTests.java @@ -19,6 +19,8 @@ import static org.assertj.core.api.Assertions.*; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.*; import static org.mockito.Mockito.any; +import static org.springframework.data.mongodb.core.query.Criteria.*; +import static org.springframework.data.mongodb.core.query.Query.*; import java.util.List; @@ -31,7 +33,14 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.data.annotation.Id; import org.springframework.data.mongodb.core.BulkOperations.BulkMode; +import org.springframework.data.mongodb.core.DefaultBulkOperations.BulkOperationContext; +import org.springframework.data.mongodb.core.convert.DbRefResolver; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.MongoConverter; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.convert.UpdateMapper; import org.springframework.data.mongodb.core.mapping.Field; +import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.Update; @@ -49,14 +58,25 @@ public class DefaultBulkOperationsUnitTests { @Mock MongoTemplate template; @Mock MongoCollection collection; + @Mock DbRefResolver dbRefResolver; + MongoConverter converter; + MongoMappingContext mappingContext; DefaultBulkOperations ops; @Before public void setUp() { + mappingContext = new MongoMappingContext(); + mappingContext.afterPropertiesSet(); + + converter = new MappingMongoConverter(dbRefResolver, mappingContext); + when(template.getCollection(anyString())).thenReturn(collection); - ops = new DefaultBulkOperations(template, BulkMode.ORDERED, "collection-1", SomeDomainType.class); + + ops = new DefaultBulkOperations(template, "collection-1", + new BulkOperationContext(BulkMode.ORDERED, mappingContext.getPersistentEntity(SomeDomainType.class), + new QueryMapper(converter), new UpdateMapper(converter))); } @Test // DATAMONGO-1518 @@ -103,6 +123,33 @@ public class DefaultBulkOperationsUnitTests { .isEqualTo(com.mongodb.client.model.Collation.builder().locale("de").build()); } + @Test // DATAMONGO-1678 + public void bulkUpdateShouldMapQueryAndUpdateCorrectly() { + + ops.updateOne(query(where("firstName").is("danerys")), Update.update("firstName", "queen danerys")).execute(); + + ArgumentCaptor>> captor = ArgumentCaptor.forClass(List.class); + + verify(collection).bulkWrite(captor.capture(), any()); + + UpdateOneModel updateModel = (UpdateOneModel) captor.getValue().get(0); + assertThat(updateModel.getFilter()).isEqualTo(new Document("first_name", "danerys")); + assertThat(updateModel.getUpdate()).isEqualTo(new Document("$set", new Document("first_name", "queen danerys"))); + } + + @Test // DATAMONGO-1678 + public void bulkRemoveShouldMapQueryCorrectly() { + + ops.remove(query(where("firstName").is("danerys"))).execute(); + + ArgumentCaptor>> captor = ArgumentCaptor.forClass(List.class); + + verify(collection).bulkWrite(captor.capture(), any()); + + DeleteManyModel updateModel = (DeleteManyModel) captor.getValue().get(0); + assertThat(updateModel.getFilter()).isEqualTo(new Document("first_name", "danerys")); + } + class SomeDomainType { @Id String id;