diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/document/mongodb/query/Update.java b/spring-data-mongodb/src/main/java/org/springframework/data/document/mongodb/query/Update.java index 9abf2bac9..3b6e6e98a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/document/mongodb/query/Update.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/document/mongodb/query/Update.java @@ -15,175 +15,197 @@ */ package org.springframework.data.document.mongodb.query; -import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; +import org.springframework.dao.InvalidDataAccessApiUsageException; + import com.mongodb.BasicDBObject; import com.mongodb.DBObject; public class Update { - public enum Position { - LAST, FIRST - } - - private HashMap criteria = new LinkedHashMap(); - - /** - * Static factory method to create an Update using the provided key - * - * @param key - * @return - */ - public static Update update(String key, Object value) { - return new Update().set(key, value); - } - - /** - * Update using the $set update modifier - * - * @param key - * @param value - * @return - */ - public Update set(String key, Object value) { - criteria.put("$set", Collections.singletonMap(key, convertValueIfNecessary(value))); - return this; - } - - /** - * Update using the $unset update modifier - * - * @param key - * @return - */ - public Update unset(String key) { - criteria.put("$unset", Collections.singletonMap(key, 1)); - return this; - } - - /** - * Update using the $inc update modifier - * - * @param key - * @param inc - * @return - */ - public Update inc(String key, Number inc) { - criteria.put("$inc", Collections.singletonMap(key, inc)); - return this; - } - - /** - * Update using the $push update modifier - * - * @param key - * @param value - * @return - */ - public Update push(String key, Object value) { - criteria.put("$push", Collections.singletonMap(key, convertValueIfNecessary(value))); - return this; - } - - /** - * Update using the $pushAll update modifier - * - * @param key - * @param values - * @return - */ - public Update pushAll(String key, Object[] values) { - Object[] convertedValues = new Object[values.length]; - for (int i = 0; i < values.length; i++) { - convertedValues[i] = convertValueIfNecessary(values[i]); - } - DBObject keyValue = new BasicDBObject(); - keyValue.put(key, convertedValues); - criteria.put("$pushAll", keyValue); - return this; - } - - /** - * Update using the $addToSet update modifier - * - * @param key - * @param value - * @return - */ - public Update addToSet(String key, Object value) { - criteria.put("$addToSet", Collections.singletonMap(key, convertValueIfNecessary(value))); - return this; - } - - /** - * Update using the $pop update modifier - * - * @param key - * @param pos - * @return - */ - public Update pop(String key, Position pos) { - criteria.put("$pop", Collections.singletonMap(key, (pos == Position.FIRST ? -1 : 1))); - return this; - } - - /** - * Update using the $pull update modifier - * - * @param key - * @param value - * @return - */ - public Update pull(String key, Object value) { - criteria.put("$pull", Collections.singletonMap(key, convertValueIfNecessary(value))); - return this; - } - - /** - * Update using the $pullAll update modifier - * - * @param key - * @param values - * @return - */ - public Update pullAll(String key, Object[] values) { - Object[] convertedValues = new Object[values.length]; - for (int i = 0; i < values.length; i++) { - convertedValues[i] = convertValueIfNecessary(values[i]); - } - DBObject keyValue = new BasicDBObject(); - keyValue.put(key, convertedValues); - criteria.put("$pullAll", keyValue); - return this; - } - - /** - * Update using the $rename update modifier - * - * @param oldName - * @param newName - * @return - */ - public Update rename(String oldName, String newName) { - criteria.put("$rename", Collections.singletonMap(oldName, newName)); - return this; - } - - public DBObject getUpdateObject() { - DBObject dbo = new BasicDBObject(); - for (String k : criteria.keySet()) { - dbo.put(k, criteria.get(k)); - } - return dbo; - } - - protected Object convertValueIfNecessary(Object value) { - if (value instanceof Enum) { - return ((Enum) value).name(); - } - return value; - } + public enum Position { + LAST, FIRST + } + + private HashMap modifierOps = new LinkedHashMap(); + + /** + * Static factory method to create an Update using the provided key + * + * @param key + * @return + */ + public static Update update(String key, Object value) { + return new Update().set(key, value); + } + + /** + * Update using the $set update modifier + * + * @param key + * @param value + * @return + */ + public Update set(String key, Object value) { + addMultiFieldOperation("$set", key, convertValueIfNecessary(value)); + return this; + } + + /** + * Update using the $unset update modifier + * + * @param key + * @return + */ + public Update unset(String key) { + addMultiFieldOperation("$unset", key, 1); + return this; + } + + /** + * Update using the $inc update modifier + * + * @param key + * @param inc + * @return + */ + public Update inc(String key, Number inc) { + addMultiFieldOperation("$inc", key, inc); + return this; + } + + /** + * Update using the $push update modifier + * + * @param key + * @param value + * @return + */ + public Update push(String key, Object value) { + addMultiFieldOperation("$push", key, convertValueIfNecessary(value)); + return this; + } + + /** + * Update using the $pushAll update modifier + * + * @param key + * @param values + * @return + */ + public Update pushAll(String key, Object[] values) { + Object[] convertedValues = new Object[values.length]; + for (int i = 0; i < values.length; i++) { + convertedValues[i] = convertValueIfNecessary(values[i]); + } + DBObject keyValue = new BasicDBObject(); + keyValue.put(key, convertedValues); + modifierOps.put("$pushAll", keyValue); + return this; + } + + /** + * Update using the $addToSet update modifier + * + * @param key + * @param value + * @return + */ + public Update addToSet(String key, Object value) { + addMultiFieldOperation("$addToSet", key, convertValueIfNecessary(value)); + return this; + } + + /** + * Update using the $pop update modifier + * + * @param key + * @param pos + * @return + */ + public Update pop(String key, Position pos) { + addMultiFieldOperation("$pop", key, + (pos == Position.FIRST ? -1 : 1)); + return this; + } + + /** + * Update using the $pull update modifier + * + * @param key + * @param value + * @return + */ + public Update pull(String key, Object value) { + addMultiFieldOperation("$pull", key, convertValueIfNecessary(value)); + return this; + } + + /** + * Update using the $pullAll update modifier + * + * @param key + * @param values + * @return + */ + public Update pullAll(String key, Object[] values) { + Object[] convertedValues = new Object[values.length]; + for (int i = 0; i < values.length; i++) { + convertedValues[i] = convertValueIfNecessary(values[i]); + } + DBObject keyValue = new BasicDBObject(); + keyValue.put(key, convertedValues); + modifierOps.put("$pullAll", keyValue); + return this; + } + + /** + * Update using the $rename update modifier + * + * @param oldName + * @param newName + * @return + */ + public Update rename(String oldName, String newName) { + addMultiFieldOperation("$rename", oldName, newName); + return this; + } + + public DBObject getUpdateObject() { + DBObject dbo = new BasicDBObject(); + for (String k : modifierOps.keySet()) { + dbo.put(k, modifierOps.get(k)); + } + return dbo; + } + + @SuppressWarnings("unchecked") + protected void addMultiFieldOperation(String operator, String key, + Object value) { + Object existingValue = this.modifierOps.get(operator); + LinkedHashMap keyValueMap; + if (existingValue == null) { + keyValueMap = new LinkedHashMap(); + this.modifierOps.put(operator, keyValueMap); + } else { + if (existingValue instanceof LinkedHashMap) { + keyValueMap = (LinkedHashMap) existingValue; + } + else { + throw new InvalidDataAccessApiUsageException("Modifier Operations should be a LinkedHashMap but was " + + existingValue.getClass()); + } + } + keyValueMap.put(key, value); + } + + protected Object convertValueIfNecessary(Object value) { + if (value instanceof Enum) { + return ((Enum) value).name(); + } + return value; + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/document/mongodb/MongoTemplateTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/document/mongodb/MongoTemplateTests.java index f1421775e..b77cb356d 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/document/mongodb/MongoTemplateTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/document/mongodb/MongoTemplateTests.java @@ -15,7 +15,12 @@ */ package org.springframework.data.document.mongodb; -import static org.hamcrest.Matchers.*; +import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isOneOf; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; import static org.springframework.data.document.mongodb.query.Criteria.where; @@ -23,9 +28,6 @@ import java.util.Arrays; import java.util.HashSet; import java.util.List; -import com.mongodb.DBCollection; -import com.mongodb.DBObject; -import com.mongodb.Mongo; import org.bson.types.ObjectId; import org.junit.Before; import org.junit.Rule; @@ -37,11 +39,20 @@ import org.springframework.dao.DataIntegrityViolationException; import org.springframework.data.document.mongodb.convert.MappingMongoConverter; import org.springframework.data.document.mongodb.convert.MongoConverter; import org.springframework.data.document.mongodb.mapping.MongoMappingContext; -import org.springframework.data.document.mongodb.query.*; +import org.springframework.data.document.mongodb.query.Criteria; +import org.springframework.data.document.mongodb.query.Index; import org.springframework.data.document.mongodb.query.Index.Duplicates; +import org.springframework.data.document.mongodb.query.Order; +import org.springframework.data.document.mongodb.query.Query; +import org.springframework.data.document.mongodb.query.Update; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import com.mongodb.DBCollection; +import com.mongodb.DBObject; +import com.mongodb.Mongo; +import com.mongodb.WriteResult; + /** * Integration test for {@link MongoTemplate}. * @@ -339,6 +350,38 @@ public class MongoTemplateTests { } } + @Test + public void testUsingUpdateWithMultipleSet() throws Exception { + + template.remove(new Query(), PersonWithIdPropertyOfTypeObjectId.class); + + PersonWithIdPropertyOfTypeObjectId p1 = new PersonWithIdPropertyOfTypeObjectId(); + p1.setFirstName("Sven"); + p1.setAge(11); + template.insert("springdata", p1); + PersonWithIdPropertyOfTypeObjectId p2 = new PersonWithIdPropertyOfTypeObjectId(); + p2.setFirstName("Mary"); + p2.setAge(21); + template.insert("springdata", p2); + + Update u = new Update().set("firstName", "Bob").set("age", 10); + + WriteResult wr = template.updateMulti("springdata", new Query(), u); + + assertThat(wr.getN(), is(2)); + + Query q1 = new Query(Criteria.where("age").in(11, 21)); + List r1 = template.find("springdata", q1, PersonWithIdPropertyOfTypeObjectId.class); + assertThat(r1.size(), is(0)); + Query q2 = new Query(Criteria.where("age").is(10)); + List r2 = template.find("springdata", q2, PersonWithIdPropertyOfTypeObjectId.class); + assertThat(r2.size(), is(2)); + for (PersonWithIdPropertyOfTypeObjectId p : r2) { + assertThat(p.getAge(), is(10)); + assertThat(p.getFirstName(), is("Bob")); + } + } + @Test public void testRemovingDocument() throws Exception {