Browse Source

Add support for `AggregationUpdate` to `BulkOperations`.

We now accept `UpdateDefinition` in `BulkOperations` to support custom update definitions and aggregation updates.

Closes #3872
Original pull request: #4344
pull/4373/head
Christoph Strobl 3 years ago committed by Mark Paluch
parent
commit
0ba857aa22
No known key found for this signature in database
GPG Key ID: 4406B84C1661DCD1
  1. 44
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/BulkOperations.java
  2. 46
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultBulkOperations.java
  3. 51
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsIntegrationTests.java

44
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/BulkOperations.java

@ -19,6 +19,7 @@ import java.util.List;
import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update; import org.springframework.data.mongodb.core.query.Update;
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.util.Pair; import org.springframework.data.util.Pair;
import com.mongodb.bulk.BulkWriteResult; import com.mongodb.bulk.BulkWriteResult;
@ -75,7 +76,19 @@ public interface BulkOperations {
* @param update {@link Update} operation to perform, must not be {@literal null}. * @param update {@link Update} operation to perform, must not be {@literal null}.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}. * @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
*/ */
BulkOperations updateOne(Query query, Update update); default BulkOperations updateOne(Query query, Update update) {
return updateOne(query, (UpdateDefinition) update);
}
/**
* Add a single update to the bulk operation. For the update request, only the first matching document is updated.
*
* @param query update criteria, must not be {@literal null}.
* @param update {@link Update} operation to perform, must not be {@literal null}.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
* @since 4.1
*/
BulkOperations updateOne(Query query, UpdateDefinition update);
/** /**
* Add a list of updates to the bulk operation. For each update request, only the first matching document is updated. * Add a list of updates to the bulk operation. For each update request, only the first matching document is updated.
@ -92,7 +105,19 @@ public interface BulkOperations {
* @param update Update operation to perform. * @param update Update operation to perform.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}. * @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
*/ */
BulkOperations updateMulti(Query query, Update update); default BulkOperations updateMulti(Query query, Update update) {
return updateMulti(query, (UpdateDefinition) update);
}
/**
* Add a single update to the bulk operation. For the update request, all matching documents are updated.
*
* @param query Update criteria.
* @param update Update operation to perform.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
* @since 4.1
*/
BulkOperations updateMulti(Query query, UpdateDefinition update);
/** /**
* Add a list of updates to the bulk operation. For each update request, all matching documents are updated. * Add a list of updates to the bulk operation. For each update request, all matching documents are updated.
@ -110,7 +135,20 @@ public interface BulkOperations {
* @param update Update operation to perform. * @param update Update operation to perform.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}. * @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
*/ */
BulkOperations upsert(Query query, Update update); default BulkOperations upsert(Query query, Update update) {
return upsert(query, (UpdateDefinition) update);
}
/**
* Add a single upsert to the bulk operation. An upsert is an update if the set of matching documents is not empty,
* else an insert.
*
* @param query Update criteria.
* @param update Update operation to perform.
* @return the current {@link BulkOperations} instance with the update added, will never be {@literal null}.
* @since 4.1
*/
BulkOperations upsert(Query query, UpdateDefinition update);
/** /**
* Add a list of upserts to the bulk operation. An upsert is an update if the set of matching documents is not empty, * Add a list of upserts to the bulk operation. An upsert is an update if the set of matching documents is not empty,

46
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultBulkOperations.java

@ -16,7 +16,6 @@
package org.springframework.data.mongodb.core; package org.springframework.data.mongodb.core;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -25,8 +24,12 @@ import org.bson.Document;
import org.bson.conversions.Bson; import org.bson.conversions.Bson;
import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisher;
import org.springframework.dao.DataIntegrityViolationException; import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.data.mapping.PersistentEntity;
import org.springframework.data.mapping.callback.EntityCallbacks; import org.springframework.data.mapping.callback.EntityCallbacks;
import org.springframework.data.mongodb.BulkOperationException; import org.springframework.data.mongodb.BulkOperationException;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.convert.QueryMapper; import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.convert.UpdateMapper; import org.springframework.data.mongodb.core.convert.UpdateMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
@ -133,12 +136,12 @@ class DefaultBulkOperations implements BulkOperations {
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public BulkOperations updateOne(Query query, Update update) { public BulkOperations updateOne(Query query, UpdateDefinition update) {
Assert.notNull(query, "Query must not be null"); Assert.notNull(query, "Query must not be null");
Assert.notNull(update, "Update must not be null"); Assert.notNull(update, "Update must not be null");
return updateOne(Collections.singletonList(Pair.of(query, update))); return update(query, update, false, false);
} }
@Override @Override
@ -155,12 +158,14 @@ class DefaultBulkOperations implements BulkOperations {
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public BulkOperations updateMulti(Query query, Update update) { public BulkOperations updateMulti(Query query, UpdateDefinition update) {
Assert.notNull(query, "Query must not be null"); Assert.notNull(query, "Query must not be null");
Assert.notNull(update, "Update must not be null"); Assert.notNull(update, "Update must not be null");
return updateMulti(Collections.singletonList(Pair.of(query, update))); update(query, update, false, true);
return this;
} }
@Override @Override
@ -176,7 +181,7 @@ class DefaultBulkOperations implements BulkOperations {
} }
@Override @Override
public BulkOperations upsert(Query query, Update update) { public BulkOperations upsert(Query query, UpdateDefinition update) {
return update(query, update, true, true); return update(query, update, true, true);
} }
@ -294,7 +299,7 @@ class DefaultBulkOperations implements BulkOperations {
maybeInvokeBeforeSaveCallback(it.getSource(), target); maybeInvokeBeforeSaveCallback(it.getSource(), target);
} }
return mapWriteModel(it.getModel()); return mapWriteModel(it.getSource(), it.getModel());
} }
/** /**
@ -306,7 +311,7 @@ class DefaultBulkOperations implements BulkOperations {
* @param multi whether to issue a multi-update. * @param multi whether to issue a multi-update.
* @return the {@link BulkOperations} with the update registered. * @return the {@link BulkOperations} with the update registered.
*/ */
private BulkOperations update(Query query, Update update, boolean upsert, boolean multi) { private BulkOperations update(Query query, UpdateDefinition update, boolean upsert, boolean multi) {
Assert.notNull(query, "Query must not be null"); Assert.notNull(query, "Query must not be null");
Assert.notNull(update, "Update must not be null"); Assert.notNull(update, "Update must not be null");
@ -322,11 +327,16 @@ class DefaultBulkOperations implements BulkOperations {
return this; return this;
} }
private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) { private WriteModel<Document> mapWriteModel(Object source, WriteModel<Document> writeModel) {
if (writeModel instanceof UpdateOneModel) { if (writeModel instanceof UpdateOneModel) {
UpdateOneModel<Document> model = (UpdateOneModel<Document>) writeModel; UpdateOneModel<Document> model = (UpdateOneModel<Document>) writeModel;
if (source instanceof AggregationUpdate aggregationUpdate) {
List<Document> pipeline = mapUpdatePipeline(aggregationUpdate);
return new UpdateOneModel<>(getMappedQuery(model.getFilter()), pipeline, model.getOptions());
}
return new UpdateOneModel<>(getMappedQuery(model.getFilter()), getMappedUpdate(model.getUpdate()), return new UpdateOneModel<>(getMappedQuery(model.getFilter()), getMappedUpdate(model.getUpdate()),
model.getOptions()); model.getOptions());
@ -335,6 +345,11 @@ class DefaultBulkOperations implements BulkOperations {
if (writeModel instanceof UpdateManyModel) { if (writeModel instanceof UpdateManyModel) {
UpdateManyModel<Document> model = (UpdateManyModel<Document>) writeModel; UpdateManyModel<Document> model = (UpdateManyModel<Document>) writeModel;
if (source instanceof AggregationUpdate aggregationUpdate) {
List<Document> pipeline = mapUpdatePipeline(aggregationUpdate);
return new UpdateManyModel<>(getMappedQuery(model.getFilter()), pipeline, model.getOptions());
}
return new UpdateManyModel<>(getMappedQuery(model.getFilter()), getMappedUpdate(model.getUpdate()), return new UpdateManyModel<>(getMappedQuery(model.getFilter()), getMappedUpdate(model.getUpdate()),
model.getOptions()); model.getOptions());
@ -357,6 +372,19 @@ class DefaultBulkOperations implements BulkOperations {
return writeModel; return writeModel;
} }
private List<Document> mapUpdatePipeline(AggregationUpdate source) {
Class<?> type = bulkOperationContext.getEntity().isPresent()
? bulkOperationContext.getEntity().map(PersistentEntity::getType).get()
: Object.class;
AggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext(type,
bulkOperationContext.getUpdateMapper().getMappingContext(), bulkOperationContext.getQueryMapper());
List<Document> pipeline = new AggregationUtil(bulkOperationContext.getQueryMapper(),
bulkOperationContext.getQueryMapper().getMappingContext()).createPipeline(source,
context);
return pipeline;
}
private Bson getMappedUpdate(Bson update) { private Bson getMappedUpdate(Bson update) {
return bulkOperationContext.getUpdateMapper().getMappedObject(update, bulkOperationContext.getEntity()); return bulkOperationContext.getUpdateMapper().getMappedObject(update, bulkOperationContext.getEntity());
} }

51
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultBulkOperationsIntegrationTests.java

@ -21,20 +21,27 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Stream;
import com.mongodb.bulk.BulkWriteResult;
import org.bson.Document; import org.bson.Document;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.data.mongodb.BulkOperationException; import org.springframework.data.mongodb.BulkOperationException;
import org.springframework.data.mongodb.core.BulkOperations.BulkMode; import org.springframework.data.mongodb.core.BulkOperations.BulkMode;
import org.springframework.data.mongodb.core.DefaultBulkOperations.BulkOperationContext; import org.springframework.data.mongodb.core.DefaultBulkOperations.BulkOperationContext;
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
import org.springframework.data.mongodb.core.convert.QueryMapper; import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.convert.UpdateMapper; import org.springframework.data.mongodb.core.convert.UpdateMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update; import org.springframework.data.mongodb.core.query.Update;
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.test.util.MongoTemplateExtension; import org.springframework.data.mongodb.test.util.MongoTemplateExtension;
import org.springframework.data.mongodb.test.util.MongoTestTemplate; import org.springframework.data.mongodb.test.util.MongoTestTemplate;
import org.springframework.data.mongodb.test.util.Template; import org.springframework.data.mongodb.test.util.Template;
@ -135,13 +142,14 @@ public class DefaultBulkOperationsIntegrationTests {
}); });
} }
@Test // DATAMONGO-934 @ParameterizedTest // DATAMONGO-934, GH-3872
public void upsertDoesUpdate() { @MethodSource("upsertArguments")
void upsertDoesUpdate(UpdateDefinition update) {
insertSomeDocuments(); insertSomeDocuments();
com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED).// com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED).//
upsert(where("value", "value1"), set("value", "value2")).// upsert(where("value", "value1"), update).//
execute(); execute();
assertThat(result).isNotNull(); assertThat(result).isNotNull();
@ -152,11 +160,12 @@ public class DefaultBulkOperationsIntegrationTests {
assertThat(result.getUpserts().size()).isZero(); assertThat(result.getUpserts().size()).isZero();
} }
@Test // DATAMONGO-934 @ParameterizedTest // DATAMONGO-934, GH-3872
public void upsertDoesInsert() { @MethodSource("upsertArguments")
void upsertDoesInsert(UpdateDefinition update) {
com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED).// com.mongodb.bulk.BulkWriteResult result = createBulkOps(BulkMode.ORDERED).//
upsert(where("_id", "1"), set("value", "v1")).// upsert(where("_id", "1"), update).//
execute(); execute();
assertThat(result).isNotNull(); assertThat(result).isNotNull();
@ -171,11 +180,37 @@ public class DefaultBulkOperationsIntegrationTests {
testUpdate(BulkMode.ORDERED, false, 2); testUpdate(BulkMode.ORDERED, false, 2);
} }
@Test // GH-3872
public void updateOneWithAggregation() {
insertSomeDocuments();
BulkOperations bulkOps = createBulkOps(BulkMode.ORDERED);
bulkOps.updateOne(where("value", "value1"), AggregationUpdate.update().set("value").toValue("value3"));
BulkWriteResult result = bulkOps.execute();
assertThat(result.getModifiedCount()).isEqualTo(1);
assertThat(operations.<Long>execute(COLLECTION_NAME, collection -> collection.countDocuments(new org.bson.Document("value", "value3")))).isOne();
}
@Test // DATAMONGO-934 @Test // DATAMONGO-934
public void updateMultiOrdered() { public void updateMultiOrdered() {
testUpdate(BulkMode.ORDERED, true, 4); testUpdate(BulkMode.ORDERED, true, 4);
} }
@Test // GH-3872
public void updateMultiWithAggregation() {
insertSomeDocuments();
BulkOperations bulkOps = createBulkOps(BulkMode.ORDERED);
bulkOps.updateMulti(where("value", "value1"), AggregationUpdate.update().set("value").toValue("value3"));
BulkWriteResult result = bulkOps.execute();
assertThat(result.getModifiedCount()).isEqualTo(2);
assertThat(operations.<Long>execute(COLLECTION_NAME, collection -> collection.countDocuments(new org.bson.Document("value", "value3")))).isEqualTo(2);
}
@Test // DATAMONGO-934 @Test // DATAMONGO-934
public void updateOneUnOrdered() { public void updateOneUnOrdered() {
testUpdate(BulkMode.UNORDERED, false, 2); testUpdate(BulkMode.UNORDERED, false, 2);
@ -355,6 +390,10 @@ public class DefaultBulkOperationsIntegrationTests {
coll.insertOne(rawDoc("4", "value2")); coll.insertOne(rawDoc("4", "value2"));
} }
private static Stream<Arguments> upsertArguments() {
return Stream.of(Arguments.of(set("value", "value2")), Arguments.of(AggregationUpdate.update().set("value").toValue("value2")));
}
private static BaseDoc newDoc(String id) { private static BaseDoc newDoc(String id) {
BaseDoc doc = new BaseDoc(); BaseDoc doc = new BaseDoc();

Loading…
Cancel
Save