From 81fc166a48ff46be6a77d84b3a291e3f0d51ef7d Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 17 Dec 2024 11:01:37 +0100 Subject: [PATCH] some hacking --- .../data/mongodb/core/MongoTemplate.java | 176 +++++++++++++++ .../core/convert/AbstractMongoConverter.java | 2 + .../mongodb/core/convert/MongoConverters.java | 14 ++ .../mongodb/core/MongoTemplateBulkTests.java | 210 ++++++++++++++++++ .../mongodb/test/util/MongoTestTemplate.java | 93 ++++++++ 5 files changed, 495 insertions(+) create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateBulkTests.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 6fa64a484..0fcf3841d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -26,6 +26,7 @@ import java.util.function.BiPredicate; import java.util.stream.Collectors; import java.util.stream.Stream; +import com.mongodb.bulk.BulkWriteResult; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.bson.Document; @@ -41,6 +42,7 @@ import org.springframework.context.ConfigurableApplicationContext; import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; import org.springframework.dao.DataAccessException; +import org.springframework.dao.DataIntegrityViolationException; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.dao.support.PersistenceExceptionTranslator; @@ -1460,6 +1462,180 @@ public class MongoTemplate : (T) doSave(collectionName, objectToSave, this.mongoConverter); } + public Collection saveAll(Collection objectsToSave) { + + List> saves = new ArrayList<>(objectsToSave.size()); + int requiredUpdateCount = -1; + int replaceCount = 0; + Class type = null; + List inserts = new ArrayList<>(objectsToSave.size()); + for(T entity : objectsToSave) { + + if(type == null) { + type = entity.getClass(); + } + AdaptibleEntity source = operations.forEntity(entity, mongoConverter.getConversionService()); + + if(source.isVersionedEntity() && !source.isNew()) { + + Query query = source.getQueryForVersion(); + + // Bump version number + T toSave = source.incrementVersion(); + MappedDocument mapped = source.toMappedDocument(mongoConverter); + UpdateDefinition update = mapped.updateWithoutId(); + + MongoPersistentEntity persistentEntity = mappingContext.getPersistentEntity(entity.getClass()); + Document queryObj = queryMapper.getMappedObject(query.getQueryObject(), persistentEntity); + Document updateObj = updateMapper.getMappedObject(update.getUpdateObject(), persistentEntity); + + saves.add(new UpdateOneModel<>(queryObj, updateObj)); + if(requiredUpdateCount < 0) { + requiredUpdateCount = 1; + } else { + requiredUpdateCount++; + } + } + if (source.isNew()) { + Document target = new Document(); + mongoConverter.write(entity, target); + saves.add(new InsertOneModel<>(target)); + inserts.add(entity); + } else { + Document target = new Document(); + mongoConverter.write(entity, target); + Document queryObj = queryMapper.getMappedObject(source.getByIdQuery().getQueryObject(), mappingContext.getPersistentEntity(entity.getClass())); + saves.add(new ReplaceOneModel<>(queryObj, target, new com.mongodb.client.model.ReplaceOptions().upsert(true))); + replaceCount++; + } + } + + BulkWriteResult result = execute(type, collection -> { + return collection.bulkWrite(saves, new BulkWriteOptions().ordered(true)); + }); + + if(requiredUpdateCount > 0) { + + + if(result.getMatchedCount() != (replaceCount + requiredUpdateCount)) { + throw new DataIntegrityViolationException("Holy Moly, Batman!"); + } + } + + + if(!inserts.isEmpty()) { + for(int i = 0;i) objectsToSave; + } + + public Collection saveAllBulkMongoDB8(Collection objectsToSave) { + + List saves = new ArrayList<>(objectsToSave.size()); + List namespaces = new ArrayList<>(objectsToSave.size()); + + + int requiredUpdateCount = -1; + int replaceCount = 0; + Class type = null; + List inserts = new ArrayList<>(objectsToSave.size()); + for(T entity : objectsToSave) { + + if(type == null) { + type = entity.getClass(); + } + + MongoPersistentEntity persistentEntity = mappingContext.getPersistentEntity(entity.getClass()); + if(!namespaces.contains(persistentEntity.getCollection())) { + namespaces.add(persistentEntity.getCollection()); + } + int nsIndex = namespaces.indexOf(persistentEntity.getCollection()); + + AdaptibleEntity source = operations.forEntity(entity, mongoConverter.getConversionService()); + + if(source.isVersionedEntity() && !source.isNew()) { + + Query query = source.getQueryForVersion(); + + // Bump version number + T toSave = source.incrementVersion(); + MappedDocument mapped = source.toMappedDocument(mongoConverter); + UpdateDefinition update = mapped.updateWithoutId(); + + Document queryObj = queryMapper.getMappedObject(query.getQueryObject(), persistentEntity); + Document updateObj = updateMapper.getMappedObject(update.getUpdateObject(), persistentEntity); + + + // meh - there's no replace op + Document saveOp = new Document("update", nsIndex) + .append("filter", queryObj) + .append("multi", false) + .append("updateMods", updateObj); + saves.add(saveOp); + + if(requiredUpdateCount < 0) { + requiredUpdateCount = 1; + } else { + requiredUpdateCount++; + } + } + if (source.isNew()) { + Document target = new Document(); + mongoConverter.write(entity, target); + + Document saveOp = new Document("insert", nsIndex) + .append("document", target); + saves.add(saveOp); + inserts.add(entity); + } else { + Document target = new Document(); + mongoConverter.write(entity, target); + Document queryObj = queryMapper.getMappedObject(source.getByIdQuery().getQueryObject(), mappingContext.getPersistentEntity(entity.getClass())); + + + Document saveOp = new Document("update", nsIndex) + .append("filter", queryObj) + .append("multi", false) + .append("updateMods", target); + saves.add(saveOp); + replaceCount++; + } + } + + BulkWriteResult result = execute(db -> { + db.runCommand(new Document()); + return null; + }); + + if(requiredUpdateCount > 0) { + + + if(result.getMatchedCount() != (replaceCount + requiredUpdateCount)) { + throw new DataIntegrityViolationException("Holy Moly, Batman!"); + } + } + + + if(!inserts.isEmpty()) { + for(int i = 0;i) objectsToSave; + } + @SuppressWarnings("unchecked") private T doSaveVersioned(AdaptibleEntity source, String collectionName) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/AbstractMongoConverter.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/AbstractMongoConverter.java index 8936074ba..06c2e178f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/AbstractMongoConverter.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/AbstractMongoConverter.java @@ -28,6 +28,7 @@ import org.springframework.data.convert.ConverterBuilder; import org.springframework.data.convert.CustomConversions; import org.springframework.data.mapping.model.EntityInstantiators; import org.springframework.data.mongodb.core.convert.MongoConverters.BigIntegerToObjectIdConverter; +import org.springframework.data.mongodb.core.convert.MongoConverters.BsonObjectIdToStringConverter; import org.springframework.data.mongodb.core.convert.MongoConverters.ObjectIdToBigIntegerConverter; import org.springframework.data.mongodb.core.convert.MongoConverters.ObjectIdToStringConverter; import org.springframework.data.mongodb.core.convert.MongoConverters.StringToObjectIdConverter; @@ -86,6 +87,7 @@ public abstract class AbstractMongoConverter implements MongoConverter, Initiali private void initializeConverters() { conversionService.addConverter(ObjectIdToStringConverter.INSTANCE); + conversionService.addConverter(BsonObjectIdToStringConverter.INSTANCE); conversionService.addConverter(StringToObjectIdConverter.INSTANCE); if (!conversionService.canConvert(ObjectId.class, BigInteger.class)) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java index d39d1f5e9..b92b52ea4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java @@ -31,6 +31,7 @@ import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import org.bson.BsonObjectId; import org.bson.BsonReader; import org.bson.BsonTimestamp; import org.bson.BsonUndefined; @@ -125,6 +126,19 @@ abstract class MongoConverters { } } + /** + * Simple singleton to convert {@link ObjectId}s to their {@link String} representation. + * + * @author Oliver Gierke + */ + enum BsonObjectIdToStringConverter implements Converter { + INSTANCE; + + public String convert(BsonObjectId id) { + return id.getValue().toString(); + } + } + /** * Simple singleton to convert {@link String}s to their {@link ObjectId} representation. * diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateBulkTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateBulkTests.java new file mode 100644 index 000000000..4b4508bd4 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateBulkTests.java @@ -0,0 +1,210 @@ +/* + * Copyright 2024. the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core; + +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import org.assertj.core.api.InstanceOfAssertFactories; +import org.bson.types.ObjectId; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.data.annotation.Version; +import org.springframework.data.auditing.IsNewAwareAuditingHandler; +import org.springframework.data.mapping.context.PersistentEntities; +import org.springframework.data.mongodb.core.MongoTemplateTests.PersonWithIdPropertyOfTypeUUIDListener; +import org.springframework.data.mongodb.core.mapping.Document; +import org.springframework.data.mongodb.test.util.Client; +import org.springframework.data.mongodb.test.util.MongoClientExtension; +import org.springframework.data.mongodb.test.util.MongoTestTemplate; + +import com.mongodb.client.MongoClient; + +/** + * @author Christoph Strobl + */ +@ExtendWith(MongoClientExtension.class) +public class MongoTemplateBulkTests { + + public static final String DB_NAME = "mongo-template-bulk-tests"; + + static @Client MongoClient client; + + ConfigurableApplicationContext context = new GenericApplicationContext(); + + MongoTestTemplate template = new MongoTestTemplate(cfg -> { + + cfg.configureDatabaseFactory(it -> { + + it.client(client); + it.defaultDb(DB_NAME); + }); + + cfg.configureMappingContext(it -> { + it.autocreateIndex(false); + it.initialEntitySet(AuditablePerson.class); + }); + + cfg.configureApplicationContext(it -> { + it.applicationContext(context); + it.addEventListener(new PersonWithIdPropertyOfTypeUUIDListener()); + }); + + cfg.configureAuditing(it -> { + it.auditingHandler(ctx -> { + return new IsNewAwareAuditingHandler(PersistentEntities.of(ctx)); + }); + }); + }); + + @BeforeEach + void beforeEach() { + template.flush(SimpleEntity.class); + } + + @Test + void justSimpleNew() { + + List entities = simpleEntities(5); + template.saveAll(entities); + + template.verify().collection(SimpleEntity.class).hasSize(5).documentsSatisfy(document -> { + assertThat(document) // + .hasEntrySatisfying("_id", value -> assertThat(value).isInstanceOf(ObjectId.class)) // + .hasEntrySatisfying("name", + value -> assertThat(value).asInstanceOf(InstanceOfAssertFactories.STRING).startsWith("name-")); + }); + + assertThat(entities).map(SimpleEntity::getId).allMatch(ObjectId::isValid); + } + + @Test + void justSimpleReplace() { + + List entities = simpleEntities(5).stream() + .peek(entity -> entity.id = "%s".formatted(entity.name.replace("name", "id"))).collect(Collectors.toList()); + template.saveAll(entities); + + template.verify().collection(SimpleEntity.class).hasSize(5).documentsSatisfy(document -> { + assertThat(document) // + .hasEntrySatisfying("_id", value -> assertThat(value).isInstanceOf(String.class)) // + .hasEntrySatisfying("name", + value -> assertThat(value).asInstanceOf(InstanceOfAssertFactories.STRING).startsWith("name-")); + }); + } + + @Test + void mixedNewReplace() { + int i = 0; + + List entities = simpleEntities(5); + for (SimpleEntity entity : entities) { + if (i % 2 == 0) { + entity.id = "%s".formatted(entity.name.replace("name", "id")); + } + i++; + } + template.saveAll(entities); + + template.verify().collection(SimpleEntity.class).documents().atPosition(0).satisfies(document -> { + assertThat(document.get("_id")).isInstanceOf(String.class); // + }).atPosition(1).satisfies(document -> { + assertThat(document.get("_id")).isInstanceOf(ObjectId.class); // + }).atPosition(2).satisfies(document -> { + assertThat(document.get("_id")).isInstanceOf(String.class); // + }); + } + + @Test + void replaceExisting() { + int i = 0; + + SimpleEntity e1 = new SimpleEntity(); + e1.id = "id-1"; + e1.name = "name-1"; + + SimpleEntity e2 = new SimpleEntity(); + e2.id = "id-2"; + e2.name = "name-2"; + + template.saveAll(List.of(e1, e2)); + + e1.name = "name-11"; + + template.saveAll(List.of(e1, e2)); + + template.verify().collection(SimpleEntity.class).documents().hasSize(2) // + .atPosition(0).satisfies(document -> { + assertThat(document.get("name")).isEqualTo("name-11"); // + }).atPosition(1).satisfies(document -> { + assertThat(document.get("name")).isEqualTo("name-2");// + }); + } + + List simpleEntities(int count) { + + List entities = new ArrayList<>(count); + + for (int i = 0; i < count; i++) { + SimpleEntity simpleEntity = new SimpleEntity(); + simpleEntity.name = "name-%s".formatted(i); + entities.add(simpleEntity); + } + return entities; + } + + @Document("simple") + static class SimpleEntity { + + String id; + String name; + + public String getId() { + return id; + } + } + + @Document("versioned") + static class VersionedEntity { + + String id; + @Version Long version; + String name; + } + +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java index 1b4c3a1e2..47d53e1d8 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java @@ -15,11 +15,16 @@ */ package org.springframework.data.mongodb.test.util; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; + +import java.util.ArrayList; import java.util.Arrays; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; +import com.mongodb.client.MongoCursor; +import org.assertj.core.api.Assertions; import org.bson.Document; import org.springframework.context.ApplicationContext; import org.springframework.data.mapping.callback.EntityCallbacks; @@ -154,4 +159,92 @@ public class MongoTestTemplate extends MongoTemplate { return null; })); } + + public DBVerifyer verify() { + return new DBVerifyer(this); + } + + public static class DBVerifyer { + + MongoTemplate template; + + public DBVerifyer(MongoTemplate template) { + this.template = template; + } + + public CollectionVerifyer collection(String collection) { + return new CollectionVerifyer(template.getCollection(collection)); + } + + public CollectionVerifyer collection(Class type) { + return collection(template.getCollectionName(type)); + } + + } + + public static class CollectionVerifyer { + + MongoCollection collection; + + public CollectionVerifyer(MongoCollection collection) { + this.collection = collection; + } + + public CollectionVerifyer hasSize(long expectedSize) { + + Assertions.assertThat(this.collection.countDocuments()).isEqualTo(expectedSize); + return this; + } + + public CollectionVerifyer documentsSatisfy(Consumer sink) { + + try (MongoCursor iterator = this.collection.find().iterator()) { + while (iterator.hasNext()) { + sink.accept(iterator.next()); + } + } + return this; + } + + public DocumentsVerifyer documents() { + ArrayList documents = new ArrayList<>(); + this.collection.find().into(documents); + return new DocumentsVerifyer(documents); + } + } + + public static class DocumentsVerifyer { + + private final ArrayList documents; + + public DocumentsVerifyer(ArrayList documents) { + this.documents = documents; + } + + public DocumentVerifyer atPosition(int index) { + return new DocumentVerifyer(this, documents.get(index)); + } + + public DocumentsVerifyer hasSize(int expectedSize) { + + assertThat(documents).hasSize(expectedSize); + return this; + } + } + + public static class DocumentVerifyer { + DocumentsVerifyer source; + Document doc; + + public DocumentVerifyer(DocumentsVerifyer source, Document doc) { + this.source = source; + this.doc = doc; + } + + public DocumentsVerifyer satisfies(Consumer sink) { + sink.accept(doc); + return source; + } + } + }