Browse Source

Fix id setting for partial updates of collections of immutable types.

We gather immutable entities of which the id has changed, in order to set them as values in the parent entity.
We now also gather unchanged entities.
So they get set with the changed one in the parent.

Closes #1907
Original pull request: #1920
pull/1925/head
Jens Schauder 1 year ago committed by Mark Paluch
parent
commit
a7d7adaaf2
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 63
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java
  2. 71
      spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryWithListsIntegrationTests.java
  3. 13
      spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/DbAction.java

63
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java

@ -15,16 +15,7 @@ @@ -15,16 +15,7 @@
*/
package org.springframework.data.jdbc.core;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.*;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
@ -241,7 +232,7 @@ class JdbcAggregateChangeExecutionContext { @@ -241,7 +232,7 @@ class JdbcAggregateChangeExecutionContext {
RelationalPersistentEntity<?> persistentEntity = getRequiredPersistentEntity(idOwningAction.getEntityType());
Object identifier = persistentEntity.getIdentifierAccessor(idOwningAction.getEntity()).getIdentifier();
Assert.state(identifier != null,() -> "Couldn't obtain a required id value for " + persistentEntity);
Assert.state(identifier != null, () -> "Couldn't obtain a required id value for " + persistentEntity);
return identifier;
}
@ -268,12 +259,22 @@ class JdbcAggregateChangeExecutionContext { @@ -268,12 +259,22 @@ class JdbcAggregateChangeExecutionContext {
}
// the id property was immutable, so we have to propagate changes up the tree
if (newEntity != action.getEntity() && action instanceof DbAction.Insert<?> insert) {
if (action instanceof DbAction.Insert<?> insert) {
Pair<?, ?> qualifier = insert.getQualifier();
Object qualifierValue = qualifier == null ? null : qualifier.getSecond();
cascadingValues.stage(insert.getDependingOn(), insert.getPropertyPath(),
qualifier == null ? null : qualifier.getSecond(), newEntity);
if (newEntity != action.getEntity()) {
cascadingValues.stage(insert.getDependingOn(), insert.getPropertyPath(),
qualifierValue, newEntity);
} else if (insert.getPropertyPath().getLeafProperty().isCollectionLike()) {
cascadingValues.gather(insert.getDependingOn(), insert.getPropertyPath(),
qualifierValue, newEntity);
}
}
}
@ -360,7 +361,7 @@ class JdbcAggregateChangeExecutionContext { @@ -360,7 +361,7 @@ class JdbcAggregateChangeExecutionContext {
static final List<MultiValueAggregator> aggregators = Arrays.asList(SetAggregator.INSTANCE, MapAggregator.INSTANCE,
ListAggregator.INSTANCE, SingleElementAggregator.INSTANCE);
Map<DbAction, Map<PersistentPropertyPath, Object>> values = new HashMap<>();
Map<DbAction, Map<PersistentPropertyPath, StagedValue>> values = new HashMap<>();
/**
* Adds a value that needs to be set in an entity higher up in the tree of entities in the aggregate. If the
@ -375,18 +376,26 @@ class JdbcAggregateChangeExecutionContext { @@ -375,18 +376,26 @@ class JdbcAggregateChangeExecutionContext {
*/
@SuppressWarnings("unchecked")
<T> void stage(DbAction<?> action, PersistentPropertyPath path, @Nullable Object qualifier, Object value) {
gather(action, path, qualifier, value);
values.get(action).get(path).isStaged = true;
}
<T> void gather(DbAction<?> action, PersistentPropertyPath path, @Nullable Object qualifier, Object value) {
MultiValueAggregator<T> aggregator = getAggregatorFor(path);
Map<PersistentPropertyPath, Object> valuesForPath = this.values.computeIfAbsent(action,
Map<PersistentPropertyPath, StagedValue> valuesForPath = this.values.computeIfAbsent(action,
dbAction -> new HashMap<>());
T currentValue = (T) valuesForPath.computeIfAbsent(path,
persistentPropertyPath -> aggregator.createEmptyInstance());
StagedValue stagedValue = valuesForPath.computeIfAbsent(path,
persistentPropertyPath -> new StagedValue(aggregator.createEmptyInstance()));
T currentValue = (T) stagedValue.value;
Object newValue = aggregator.add(currentValue, qualifier, value);
valuesForPath.put(path, newValue);
stagedValue.value = newValue;
valuesForPath.put(path, stagedValue);
}
private MultiValueAggregator getAggregatorFor(PersistentPropertyPath path) {
@ -408,7 +417,21 @@ class JdbcAggregateChangeExecutionContext { @@ -408,7 +417,21 @@ class JdbcAggregateChangeExecutionContext {
* property.
*/
void forEachPath(DbAction<?> dbAction, BiConsumer<PersistentPropertyPath, Object> action) {
values.getOrDefault(dbAction, Collections.emptyMap()).forEach(action);
values.getOrDefault(dbAction, Collections.emptyMap()).forEach((persistentPropertyPath, stagedValue) -> {
if (stagedValue.isStaged) {
action.accept(persistentPropertyPath, stagedValue.value);
}
});
}
}
private static class StagedValue {
Object value;
boolean isStaged;
public StagedValue(Object value) {
this.value = value;
}
}

71
spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryWithListsIntegrationTests.java

@ -32,6 +32,7 @@ import org.springframework.context.annotation.Bean; @@ -32,6 +32,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.data.annotation.Id;
import org.springframework.data.annotation.PersistenceCreator;
import org.springframework.data.jdbc.repository.support.JdbcRepositoryFactory;
import org.springframework.data.jdbc.testing.EnabledOnFeature;
import org.springframework.data.jdbc.testing.IntegrationTest;
@ -55,8 +56,7 @@ public class JdbcRepositoryWithListsIntegrationTests { @@ -55,8 +56,7 @@ public class JdbcRepositoryWithListsIntegrationTests {
private static DummyEntity createDummyEntity() {
DummyEntity entity = new DummyEntity();
entity.setName("Entity Name");
DummyEntity entity = new DummyEntity(null, "Entity Name", new ArrayList<>());
return entity;
}
@ -94,7 +94,7 @@ public class JdbcRepositoryWithListsIntegrationTests { @@ -94,7 +94,7 @@ public class JdbcRepositoryWithListsIntegrationTests {
assertThat(reloaded.content) //
.isNotNull() //
.extracting(e -> e.id) //
.containsExactlyInAnyOrder(element1.id, element2.id);
.containsExactlyInAnyOrder(entity.content.get(0).id, entity.content.get(1).id);
}
@Test // GH-1159
@ -147,9 +147,9 @@ public class JdbcRepositoryWithListsIntegrationTests { @@ -147,9 +147,9 @@ public class JdbcRepositoryWithListsIntegrationTests {
@EnabledOnFeature(SUPPORTS_GENERATED_IDS_IN_REFERENCED_ENTITIES)
public void updateList() {
Element element1 = createElement("one");
Element element2 = createElement("two");
Element element3 = createElement("three");
Element element1 = new Element("one");
Element element2 = new Element("two");
Element element3 = new Element("three");
DummyEntity entity = createDummyEntity();
entity.content.add(element1);
@ -157,14 +157,15 @@ public class JdbcRepositoryWithListsIntegrationTests { @@ -157,14 +157,15 @@ public class JdbcRepositoryWithListsIntegrationTests {
entity = repository.save(entity);
entity.content.remove(element1);
element2.content = "two changed";
entity.content.remove(0);
entity.content.set(0, new Element(entity.content.get(0).id, "two changed"));
entity.content.add(element3);
entity = repository.save(entity);
assertThat(entity.id).isNotNull();
assertThat(entity.content).allMatch(v -> v.id != null);
assertThat(entity.content).hasSize(2);
DummyEntity reloaded = repository.findById(entity.id).orElseThrow(AssertionFailedError::new);
@ -175,8 +176,8 @@ public class JdbcRepositoryWithListsIntegrationTests { @@ -175,8 +176,8 @@ public class JdbcRepositoryWithListsIntegrationTests {
assertThat(reloaded.content) //
.extracting(e -> e.id, e -> e.content) //
.containsExactly( //
tuple(element2.id, "two changed"), //
tuple(element3.id, "three") //
tuple(entity.content.get(0).id, "two changed"), //
tuple(entity.content.get(1).id, "three") //
);
Long count = template.queryForObject("SELECT count(1) FROM Element", new HashMap<>(), Long.class);
@ -186,8 +187,8 @@ public class JdbcRepositoryWithListsIntegrationTests { @@ -186,8 +187,8 @@ public class JdbcRepositoryWithListsIntegrationTests {
@Test // DATAJDBC-130
public void deletingWithList() {
Element element1 = createElement("one");
Element element2 = createElement("two");
Element element1 = new Element("one");
Element element2 = new Element("two");
DummyEntity entity = createDummyEntity();
entity.content.add(element1);
@ -203,13 +204,6 @@ public class JdbcRepositoryWithListsIntegrationTests { @@ -203,13 +204,6 @@ public class JdbcRepositoryWithListsIntegrationTests {
assertThat(count).isEqualTo(0);
}
private Element createElement(String content) {
Element element = new Element();
element.content = content;
return element;
}
interface DummyEntityRepository extends CrudRepository<DummyEntity, Long> {}
interface RootRepository extends CrudRepository<Root, Long> {}
@ -229,43 +223,22 @@ public class JdbcRepositoryWithListsIntegrationTests { @@ -229,43 +223,22 @@ public class JdbcRepositoryWithListsIntegrationTests {
}
}
static class DummyEntity {
record DummyEntity(@Id Long id, String name, List<Element> content) {
}
String name;
List<Element> content = new ArrayList<>();
@Id private Long id;
record Element(@Id Long id, String content) {
public String getName() {
return this.name;
}
@PersistenceCreator
Element {}
public List<Element> getContent() {
return this.content;
Element() {
this(null, null);
}
public Long getId() {
return this.id;
Element(String content) {
this(null, content);
}
public void setName(String name) {
this.name = name;
}
public void setContent(List<Element> content) {
this.content = content;
}
public void setId(Long id) {
this.id = id;
}
}
static class Element {
String content;
@Id private Long id;
public Element() {}
}
static class Root {

13
spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/DbAction.java

@ -15,9 +15,12 @@ @@ -15,9 +15,12 @@
*/
package org.springframework.data.relational.core.conversion;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import org.springframework.data.mapping.PersistentPropertyPath;
@ -479,15 +482,13 @@ public interface DbAction<T> { @@ -479,15 +482,13 @@ public interface DbAction<T> {
default Pair<PersistentPropertyPath<RelationalPersistentProperty>, Object> getQualifier() {
Map<PersistentPropertyPath<RelationalPersistentProperty>, Object> qualifiers = getQualifiers();
if (qualifiers.isEmpty())
if (qualifiers.isEmpty()) {
return null;
if (qualifiers.size() > 1) {
throw new IllegalStateException("Can't handle more then one qualifier");
}
Map.Entry<PersistentPropertyPath<RelationalPersistentProperty>, Object> entry = qualifiers.entrySet().iterator()
.next();
Set<Map.Entry<PersistentPropertyPath<RelationalPersistentProperty>, Object>> entries = qualifiers.entrySet();
Map.Entry<PersistentPropertyPath<RelationalPersistentProperty>, Object> entry = entries.stream().sorted(Comparator.comparing(e -> -e.getKey().getLength())).findFirst().get();
if (entry.getValue() == null) {
return null;
}

Loading…
Cancel
Save