diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoActionOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoActionOperation.java index 446c9e557..467f6586f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoActionOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoActionOperation.java @@ -21,9 +21,10 @@ package org.springframework.data.mongodb.core; * * @author Mark Pollack * @author Oliver Gierke + * @author Christoph Strobl * @see MongoAction */ public enum MongoActionOperation { - REMOVE, UPDATE, INSERT, INSERT_LIST, SAVE, BULK; + REMOVE, UPDATE, INSERT, INSERT_LIST, SAVE, BULK, REPLACE; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 441be2fd8..83c66d188 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -2082,8 +2082,6 @@ public class MongoTemplate Assert.isTrue(query.getLimit() <= 1, "Query must not define a limit other than 1 ore none"); Assert.isTrue(query.getSkip() <= 0, "Query must not define skip"); - CollectionPreparerDelegate collectionPreparer = createDelegate(query); - UpdateContext updateContext = null; if (replacement instanceof AggregationUpdate au) { updateContext = queryOperations.updateContext(au, query, options.isUpsert()); @@ -2097,10 +2095,14 @@ public class MongoTemplate maybeEmitEvent(new BeforeSaveEvent<>(replacement, mappedReplacement, collectionName)); replacement = maybeCallBeforeSave(replacement, mappedReplacement, collectionName); - UpdateResult result = doReplace(options, entityType, collectionName, updateContext, collectionPreparer, - mappedReplacement); + MongoAction action = new MongoAction(writeConcern, MongoActionOperation.REPLACE, collectionName, entityType, + mappedReplacement, updateContext.getQueryObject()); + + UpdateResult result = doReplace(options, entityType, collectionName, updateContext, + createCollectionPreparer(query, action), mappedReplacement); if (result.wasAcknowledged()) { + maybeEmitEvent(new AfterSaveEvent<>(replacement, mappedReplacement, collectionName)); maybeCallAfterSave(replacement, mappedReplacement, collectionName); } @@ -2773,6 +2775,17 @@ public class MongoTemplate return CollectionPreparerDelegate.of(query); } + CollectionPreparer> createCollectionPreparer(Query query, @Nullable MongoAction action) { + CollectionPreparer> collectionPreparer = createDelegate(query); + if (action == null) { + return collectionPreparer; + } + return collectionPreparer.andThen(collection -> { + WriteConcern writeConcern = prepareWriteConcern(action); + return writeConcern != null ? collection.withWriteConcern(writeConcern) : collection; + }); + } + /** * Customize this part for findAndReplace. * @@ -2809,9 +2822,11 @@ public class MongoTemplate } private UpdateResult doReplace(ReplaceOptions options, Class entityType, String collectionName, - UpdateContext updateContext, CollectionPreparerDelegate collectionPreparer, Document replacement) { + UpdateContext updateContext, CollectionPreparer> collectionPreparer, + Document replacement) { MongoPersistentEntity persistentEntity = mappingContext.getPersistentEntity(entityType); + ReplaceCallback replaceCallback = new ReplaceCallback(collectionPreparer, updateContext.getMappedQuery(persistentEntity), replacement, updateContext.getReplaceOptions(entityType, it -> { it.upsert(options.isUpsert()); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java index 70c1d49f7..95cf4f767 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java @@ -17,6 +17,7 @@ package org.springframework.data.mongodb.core; import static org.springframework.data.mongodb.core.query.SerializationUtils.*; +import org.springframework.data.mongodb.core.CollectionPreparerSupport.CollectionPreparerDelegate; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.util.function.Tuple2; @@ -1968,10 +1969,14 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati return createMono(collectionName, collection -> { - ReactiveCollectionPreparerDelegate collectionPreparer = ReactiveCollectionPreparerDelegate.of(query); - MongoCollection collectionToUse = collectionPreparer.prepare(collection); + Document mappedUpdate = updateContext.getMappedUpdate(entity); + + MongoAction action = new MongoAction(writeConcern, MongoActionOperation.REPLACE, collectionName, entityType, + mappedUpdate, updateContext.getQueryObject()); - return collectionToUse.replaceOne(updateContext.getMappedQuery(entity), updateContext.getMappedUpdate(entity), updateContext.getReplaceOptions(entityType, it -> { + MongoCollection collectionToUse = createCollectionPreparer(query, action).prepare(collection); + + return collectionToUse.replaceOne(updateContext.getMappedQuery(entity), mappedUpdate, updateContext.getReplaceOptions(entityType, it -> { it.upsert(options.isUpsert()); })); }); @@ -2359,6 +2364,21 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati objectCallback, collectionName); } + CollectionPreparer> createCollectionPreparer(Query query) { + return ReactiveCollectionPreparerDelegate.of(query); + } + + CollectionPreparer> createCollectionPreparer(Query query, @Nullable MongoAction action) { + CollectionPreparer> collectionPreparer = createCollectionPreparer(query); + if (action == null) { + return collectionPreparer; + } + return collectionPreparer.andThen(collection -> { + WriteConcern writeConcern = prepareWriteConcern(action); + return writeConcern != null ? collection.withWriteConcern(writeConcern) : collection; + }); + } + /** * Map the results of an ad-hoc query on the default MongoDB collection to a List of the specified targetClass while * using sourceClass for mapping the query. diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java index 3101818e4..d196394fd 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java @@ -178,6 +178,7 @@ public class MongoTemplateUnitTests extends MongoOperationsUnitTests { when(collection.withWriteConcern(any())).thenReturn(collectionWithWriteConcern); when(collection.distinct(anyString(), any(Document.class), any())).thenReturn(distinctIterable); when(collectionWithWriteConcern.deleteOne(any(Bson.class), any())).thenReturn(deleteResult); + when(collectionWithWriteConcern.replaceOne(any(), any(), any(com.mongodb.client.model.ReplaceOptions.class))).thenReturn(updateResult); when(findIterable.projection(any())).thenReturn(findIterable); when(findIterable.sort(any(org.bson.Document.class))).thenReturn(findIterable); when(findIterable.collation(any())).thenReturn(findIterable); @@ -2497,6 +2498,22 @@ public class MongoTemplateUnitTests extends MongoOperationsUnitTests { assertThat(options.getValue().getHintString()).isEqualTo("index-to-use"); } + @Test // GH-4462 + void replaceShouldApplyWriteConcern() { + + template.setWriteConcernResolver(new WriteConcernResolver() { + public WriteConcern resolve(MongoAction action) { + + assertThat(action.getMongoActionOperation()).isEqualTo(MongoActionOperation.REPLACE); + return WriteConcern.UNACKNOWLEDGED; + } + }); + + template.replace(new BasicQuery("{}").withHint("index-to-use"), new Sith(), ReplaceOptions.replaceOptions().upsert()); + + verify(collection).withWriteConcern(eq(WriteConcern.UNACKNOWLEDGED)); + } + class AutogenerateableId { @Id BigInteger id; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateUnitTests.java index 8d368e280..c7b1e8f03 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateUnitTests.java @@ -20,6 +20,8 @@ import static org.mockito.Mockito.*; import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; import static org.springframework.data.mongodb.test.util.Assertions.assertThat; +import com.mongodb.WriteConcern; +import org.springframework.data.mongodb.core.MongoTemplateUnitTests.Sith; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -1602,6 +1604,83 @@ public class ReactiveMongoTemplateUnitTests { verify(changeStreamPublisher).startAfter(eq(token)); } + @Test // GH-4462 + void replaceShouldUseCollationWhenPresent() { + + template.replace(new BasicQuery("{}").collation(Collation.of("fr")), new Jedi()).subscribe(); + + ArgumentCaptor options = ArgumentCaptor + .forClass(com.mongodb.client.model.ReplaceOptions.class); + verify(collection).replaceOne(any(Bson.class), any(), options.capture()); + + assertThat(options.getValue().isUpsert()).isFalse(); + assertThat(options.getValue().getCollation().getLocale()).isEqualTo("fr"); + } + + @Test // GH-4462 + void replaceShouldNotUpsertByDefault() { + + template.replace(new BasicQuery("{}"), new MongoTemplateUnitTests.Sith()).subscribe(); + + ArgumentCaptor options = ArgumentCaptor + .forClass(com.mongodb.client.model.ReplaceOptions.class); + verify(collection).replaceOne(any(Bson.class), any(), options.capture()); + + assertThat(options.getValue().isUpsert()).isFalse(); + } + + @Test // GH-4462 + void replaceShouldUpsert() { + + template.replace(new BasicQuery("{}"), new MongoTemplateUnitTests.Sith(), org.springframework.data.mongodb.core.ReplaceOptions.replaceOptions().upsert()).subscribe(); + + ArgumentCaptor options = ArgumentCaptor + .forClass(com.mongodb.client.model.ReplaceOptions.class); + verify(collection).replaceOne(any(Bson.class), any(), options.capture()); + + assertThat(options.getValue().isUpsert()).isTrue(); + } + + @Test // GH-4462 + void replaceShouldUseDefaultCollationWhenPresent() { + + template.replace(new BasicQuery("{}"), new MongoTemplateUnitTests.Sith(), org.springframework.data.mongodb.core.ReplaceOptions.replaceOptions()).subscribe(); + + ArgumentCaptor options = ArgumentCaptor + .forClass(com.mongodb.client.model.ReplaceOptions.class); + verify(collection).replaceOne(any(Bson.class), any(), options.capture()); + + assertThat(options.getValue().getCollation().getLocale()).isEqualTo("de_AT"); + } + + @Test // GH-4462 + void replaceShouldUseHintIfPresent() { + + template.replace(new BasicQuery("{}").withHint("index-to-use"), new MongoTemplateUnitTests.Sith(), org.springframework.data.mongodb.core.ReplaceOptions.replaceOptions().upsert()).subscribe(); + + ArgumentCaptor options = ArgumentCaptor + .forClass(com.mongodb.client.model.ReplaceOptions.class); + verify(collection).replaceOne(any(Bson.class), any(), options.capture()); + + assertThat(options.getValue().getHintString()).isEqualTo("index-to-use"); + } + + @Test // GH-4462 + void replaceShouldApplyWriteConcern() { + + template.setWriteConcernResolver(new WriteConcernResolver() { + public WriteConcern resolve(MongoAction action) { + + assertThat(action.getMongoActionOperation()).isEqualTo(MongoActionOperation.REPLACE); + return WriteConcern.UNACKNOWLEDGED; + } + }); + + template.replace(new BasicQuery("{}").withHint("index-to-use"), new Sith(), org.springframework.data.mongodb.core.ReplaceOptions.replaceOptions().upsert()).subscribe(); + + verify(collection).withWriteConcern(eq(WriteConcern.UNACKNOWLEDGED)); + } + private void stubFindSubscribe(Document document) { Publisher realPublisher = Flux.just(document);