diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index a6ac9c360..11fc19622 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -763,27 +763,33 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { } protected void doInsertAll(Collection listToSave, MongoWriter writer) { - Map> objs = new HashMap>(); - for (T o : listToSave) { + Map> elementsByCollection = new HashMap>(); + + for (T element : listToSave) { + + if (element == null) { + continue; + } + + MongoPersistentEntity entity = mappingContext.getPersistentEntity(element.getClass()); - MongoPersistentEntity entity = mappingContext.getPersistentEntity(o.getClass()); if (entity == null) { - throw new InvalidDataAccessApiUsageException("No Persitent Entity information found for the class " - + o.getClass().getName()); + throw new InvalidDataAccessApiUsageException("No PersistentEntity information found for " + element.getClass()); } + String collection = entity.getCollection(); + List collectionElements = elementsByCollection.get(collection); - List objList = objs.get(collection); - if (null == objList) { - objList = new ArrayList(); - objs.put(collection, objList); + if (null == collectionElements) { + collectionElements = new ArrayList(); + elementsByCollection.put(collection, collectionElements); } - objList.add(o); + collectionElements.add(element); } - for (Map.Entry> entry : objs.entrySet()) { + for (Map.Entry> entry : elementsByCollection.entrySet()) { doInsertBatch(entry.getKey(), entry.getValue(), this.mongoConverter); } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java index a6076cb2f..ded2b9bf7 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java @@ -74,6 +74,7 @@ import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; import com.mongodb.BasicDBObject; @@ -188,6 +189,7 @@ public class MongoTemplateTests { template.dropCollection(DocumentWithDBRefCollection.class); template.dropCollection(SomeContent.class); template.dropCollection(SomeTemplate.class); + template.dropCollection(Address.class); } @Test @@ -2740,6 +2742,23 @@ public class MongoTemplateTests { assertThat(template.findAll(DBObject.class, "collection"), hasSize(0)); } + /** + * @see DATAMONGO-1207 + */ + @Test + public void ignoresNullElementsForInsertAll() { + + Address newYork = new Address("NY", "New York"); + Address washington = new Address("DC", "Washington"); + + template.insertAll(Arrays.asList(newYork, null, washington)); + + List
result = template.findAll(Address.class); + + assertThat(result, hasSize(2)); + assertThat(result, hasItems(newYork, washington)); + } + static class DoucmentWithNamedIdField { @Id String someIdKey; @@ -2926,6 +2945,41 @@ public class MongoTemplateTests { String state; String city; + + Address() {} + + Address(String state, String city) { + this.state = state; + this.city = city; + } + + @Override + public boolean equals(Object obj) { + + if (obj == this) { + return true; + } + + if (!(obj instanceof Address)) { + return false; + } + + Address that = (Address) obj; + + return ObjectUtils.nullSafeEquals(this.city, that.city) && // + ObjectUtils.nullSafeEquals(this.state, that.state); + } + + @Override + public int hashCode() { + + int result = 17; + + result += 31 * ObjectUtils.nullSafeHashCode(this.city); + result += 31 * ObjectUtils.nullSafeHashCode(this.state); + + return result; + } } static class VersionedPerson {