diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index e7419a06f..762dfb363 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -932,7 +932,9 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { maybeEmitEvent(new BeforeSaveEvent(o, dbDoc, collectionName)); dbObjectList.add(dbDoc); } - List ids = insertDBObjectList(collectionName, dbObjectList); + + List ids = consolidateIdentifiers(insertDBObjectList(collectionName, dbObjectList), dbObjectList); + int i = 0; for (T obj : batchToSave) { if (i < ids.size()) { @@ -1037,6 +1039,8 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { }); } + // TODO: 2.0 - Change method signature to return List and return all identifiers (DATAMONGO-1513, + // DATAMONGO-1519) protected List insertDBObjectList(final String collectionName, final List dbDocList) { if (dbDocList.isEmpty()) { return Collections.emptyList(); @@ -2115,6 +2119,28 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { return resolved == null ? ex : resolved; } + /** + * Returns all identifiers for the given documents. Will augment the given identifiers and fill in only the ones that + * are {@literal null} currently. This would've been better solved in {@link #insertDBObjectList(String, List)} + * directly but would require a signature change of that method. + * + * @param ids + * @param documents + * @return TODO: Remove for 2.0 and change method signature of {@link #insertDBObjectList(String, List)}. + */ + private static List consolidateIdentifiers(List ids, List documents) { + + List result = new ArrayList(ids.size()); + + for (int i = 0; i < ids.size(); i++) { + + ObjectId objectId = ids.get(i); + result.add(objectId == null ? documents.get(i).get(ID_FIELD) : objectId); + } + + return result; + } + // Callback implementations /** diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java index 86ce2a08b..62f51e5a7 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java @@ -33,6 +33,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.UUID; import org.bson.types.ObjectId; import org.joda.time.DateTime; @@ -43,6 +44,7 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ConfigurableApplicationContext; import org.springframework.core.convert.converter.Converter; import org.springframework.dao.DataAccessException; import org.springframework.dao.DataIntegrityViolationException; @@ -70,11 +72,14 @@ import org.springframework.data.mongodb.core.index.IndexField; import org.springframework.data.mongodb.core.index.IndexInfo; import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.data.mongodb.core.mapping.event.AbstractMongoEventListener; +import org.springframework.data.mongodb.core.mapping.event.BeforeSaveEvent; import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; import org.springframework.data.util.CloseableIterator; +import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.util.ObjectUtils; @@ -114,6 +119,7 @@ public class MongoTemplateTests { @Autowired MongoTemplate template; @Autowired MongoDbFactory factory; + @Autowired ConfigurableApplicationContext context; MongoTemplate mappingTemplate; org.springframework.data.util.Version mongoVersion; @@ -3164,6 +3170,28 @@ public class MongoTemplateTests { assertThat(template.findOne(query(where("id").is(wgj.id)), WithGeoJson.class).point, is(equalTo(wgj.point))); } + /** + * @see DATAMONGO-1513 + */ + @Test + @DirtiesContext + public void populatesIdsAddedByEventListener() { + + context.addApplicationListener(new AbstractMongoEventListener() { + + @Override + public void onBeforeSave(BeforeSaveEvent event) { + event.getDBObject().put("_id", UUID.randomUUID().toString()); + } + }); + + Document document = new Document(); + + template.insertAll(Arrays.asList(document)); + + assertThat(document.id, is(notNullValue())); + } + static class DoucmentWithNamedIdField { @Id String someIdKey;