diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/BasicUpdate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/BasicUpdate.java index 68d96d3ab..12843ce62 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/BasicUpdate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/BasicUpdate.java @@ -15,16 +15,21 @@ */ package org.springframework.data.mongodb.core.query; +import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import org.bson.Document; + import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; /** + * {@link Document}-based {@link Update} variant. + * * @author Thomas Risberg * @author John Brisbin * @author Oliver Gierke @@ -36,12 +41,10 @@ public class BasicUpdate extends Update { private final Document updateObject; public BasicUpdate(String updateString) { - super(); - this.updateObject = Document.parse(updateString); + this(Document.parse(updateString)); } public BasicUpdate(Document updateObject) { - super(); this.updateObject = updateObject; } @@ -89,7 +92,17 @@ public class BasicUpdate extends Update { @Override public Update pullAll(String key, Object[] values) { - setOperationValue("$pullAll", key, List.of(values)); + setOperationValue("$pullAll", key, List.of(values), (o, o2) -> { + + if (o instanceof List prev && o2 instanceof List currentValue) { + List merged = new ArrayList<>(prev.size() + currentValue.size()); + merged.addAll(prev); + merged.addAll(currentValue); + return merged; + } + + return o2; + }); return this; } @@ -109,21 +122,31 @@ public class BasicUpdate extends Update { return updateObject; } - void setOperationValue(String operator, String key, Object value) { + void setOperationValue(String operator, String key, @Nullable Object value) { + setOperationValue(operator, key, value, (o, o2) -> o2); + } + + void setOperationValue(String operator, String key, @Nullable Object value, + BiFunction mergeFunction) { if (!updateObject.containsKey(operator)) { updateObject.put(operator, Collections.singletonMap(key, value)); } else { - Object existingValue = updateObject.get(operator); - if (existingValue instanceof Map existing) { + Object o = updateObject.get(operator); + if (o instanceof Map existing) { Map target = new LinkedHashMap<>(existing); - target.put(key, value); + + if (target.containsKey(key)) { + target.put(key, mergeFunction.apply(target.get(key), value)); + } else { + target.put(key, value); + } updateObject.put(operator, target); } else { throw new IllegalStateException( "Cannot add ['%s' : { '%s' : ... }]. Operator already exists with value of type [%s] which is not suitable for appending" .formatted(operator, key, - existingValue != null ? ClassUtils.getShortName(existingValue.getClass()) : "null")); + o != null ? ClassUtils.getShortName(o.getClass()) : "null")); } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java index 552e2b252..2fc2e2d0f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java @@ -447,13 +447,11 @@ public class Update implements UpdateDefinition { if (existingValue == null) { keyValueMap = new Document(); this.modifierOps.put(operator, keyValueMap); + } else if (existingValue instanceof Document document) { + keyValueMap = document; } else { - if (existingValue instanceof Document document) { - keyValueMap = document; - } else { - throw new InvalidDataAccessApiUsageException( - "Modifier Operations should be a LinkedHashMap but was " + existingValue.getClass()); - } + throw new InvalidDataAccessApiUsageException( + "Modifier Operations should be a LinkedHashMap but was " + existingValue.getClass()); } keyValueMap.put(key, value); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/BasicUpdateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/BasicUpdateUnitTests.java index fba105689..dacc27023 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/BasicUpdateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/BasicUpdateUnitTests.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -15,9 +15,12 @@ */ package org.springframework.data.mongodb.core.query; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.*; +import static org.springframework.data.mongodb.test.util.Assertions.*; import static org.springframework.data.mongodb.test.util.Assertions.assertThat; +import java.util.Arrays; +import java.util.List; import java.util.function.Function; import java.util.stream.Stream; @@ -27,12 +30,16 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.MethodSource; + import org.springframework.data.mongodb.core.query.Update.Position; /** + * Unit tests for {@link BasicUpdate}. + * * @author Christoph Strobl + * @author Mark Paluch */ -public class BasicUpdateUnitTests { +class BasicUpdateUnitTests { @Test // GH-4918 void setOperationValueShouldAppendsOpsCorrectly() { @@ -80,8 +87,18 @@ public class BasicUpdateUnitTests { .containsKey("%s.key-2".formatted(operator)); } - static Stream updateOpArgs() { + @Test // GH-4918 + void shouldNotOverridePullAll() { + Document source = Document.parse("{ '$pullAll' : { 'key-1' : ['value-1'] } }"); + Update update = new BasicUpdate(source).pullAll("key-1", new String[] { "value-2" }).pullAll("key-2", + new String[] { "value-3" }); + + assertThat(update.getUpdateObject()).containsEntry("$pullAll.key-1", Arrays.asList("value-1", "value-2")) + .containsEntry("$pullAll.key-2", List.of("value-3")); + } + + static Stream updateOpArgs() { return Stream.of( // Arguments.of("$set", (Function) update -> update.set("key-2", "value-2")), Arguments.of("$unset", (Function) update -> update.unset("key-2")), diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VersionedPersonRepositoryIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VersionedPersonRepositoryIntegrationTests.java index 9da6962b8..f4e1e0282 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VersionedPersonRepositoryIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VersionedPersonRepositoryIntegrationTests.java @@ -15,13 +15,14 @@ */ package org.springframework.data.mongodb.repository; -import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.*; import org.bson.Document; import org.bson.types.ObjectId; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.ComponentScan.Filter; import org.springframework.context.annotation.Configuration; @@ -40,12 +41,13 @@ import org.springframework.test.context.junit.jupiter.SpringExtension; import com.mongodb.client.MongoClient; /** + * Integration tests for Repositories using optimistic locking. + * * @author Christoph Strobl - * @since 2025/03 */ @ExtendWith({ MongoClientExtension.class, SpringExtension.class }) @ContextConfiguration -public class VersionedPersonRepositoryIntegrationTests { +class VersionedPersonRepositoryIntegrationTests { static @Client MongoClient mongoClient; @@ -70,14 +72,15 @@ public class VersionedPersonRepositoryIntegrationTests { @BeforeEach void beforeEach() { - MongoTestUtils.flushCollection("versioned-person-tests", template.getCollectionName(VersionedPersonWithCounter.class), - mongoClient); + MongoTestUtils.flushCollection("versioned-person-tests", + template.getCollectionName(VersionedPersonWithCounter.class), mongoClient); } @Test // GH-4918 void updatesVersionedTypeCorrectly() { - VersionedPerson person = template.insert(VersionedPersonWithCounter.class).one(new VersionedPersonWithCounter("Donald", "Duckling")); + VersionedPerson person = template.insert(VersionedPersonWithCounter.class) + .one(new VersionedPersonWithCounter("Donald", "Duckling")); int updateCount = versionedPersonRepository.findAndSetFirstnameToLastnameByLastname(person.getLastname()); @@ -93,7 +96,8 @@ public class VersionedPersonRepositoryIntegrationTests { @Test // GH-4918 void updatesVersionedTypeCorrectlyWhenUpdateIsUsingInc() { - VersionedPerson person = template.insert(VersionedPersonWithCounter.class).one(new VersionedPersonWithCounter("Donald", "Duckling")); + VersionedPerson person = template.insert(VersionedPersonWithCounter.class) + .one(new VersionedPersonWithCounter("Donald", "Duckling")); int updateCount = versionedPersonRepository.findAndIncCounterByLastname(person.getLastname()); @@ -103,13 +107,15 @@ public class VersionedPersonRepositoryIntegrationTests { return collection.find(new Document("_id", new ObjectId(person.getId()))).first(); }); - assertThat(document).containsEntry("lastname", "Duckling").containsEntry("version", 1L).containsEntry("counter", 42); + assertThat(document).containsEntry("lastname", "Duckling").containsEntry("version", 1L).containsEntry("counter", + 42); } @Test // GH-4918 void updatesVersionedTypeCorrectlyWhenUpdateCoversVersionBump() { - VersionedPerson person = template.insert(VersionedPersonWithCounter.class).one(new VersionedPersonWithCounter("Donald", "Duckling")); + VersionedPerson person = template.insert(VersionedPersonWithCounter.class) + .one(new VersionedPersonWithCounter("Donald", "Duckling")); int updateCount = versionedPersonRepository.findAndSetFirstnameToLastnameIncVersionByLastname(person.getLastname(), 10); @@ -123,7 +129,7 @@ public class VersionedPersonRepositoryIntegrationTests { assertThat(document).containsEntry("firstname", "Duckling").containsEntry("version", 10L); } - public interface VersionedPersonRepository extends CrudRepository { + interface VersionedPersonRepository extends CrudRepository { @Update("{ '$set': { 'firstname' : ?0 } }") int findAndSetFirstnameToLastnameByLastname(String lastname); @@ -156,5 +162,7 @@ public class VersionedPersonRepositoryIntegrationTests { public void setCounter(int counter) { this.counter = counter; } + } + }