diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java index 6d9fb662d..dbef6d1e1 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java @@ -16,7 +16,6 @@ package org.springframework.data.jdbc.core; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; @@ -24,7 +23,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -58,7 +56,6 @@ import org.springframework.data.util.Streamable; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; -import org.springframework.util.ObjectUtils; /** * {@link JdbcAggregateOperations} implementation, storing aggregates in and obtaining them from a JDBC data store. @@ -176,8 +173,19 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations { @Override public List saveAll(Iterable instances) { - return doWithBatch(instances, entity -> changeCreatorSelectorForSave(entity).apply(entity), this::verifyIdProperty, - this::performSaveAll); + + Assert.notNull(instances, "Aggregate instances must not be null"); + + if (!instances.iterator().hasNext()) { + return Collections.emptyList(); + } + + List> entityAndChangeCreators = new ArrayList<>(); + for (T instance : instances) { + verifyIdProperty(instance); + entityAndChangeCreators.add(new EntityAndChangeCreator<>(instance, changeCreatorSelectorForSave(instance))); + } + return performSaveAll(entityAndChangeCreators); } /** @@ -198,7 +206,21 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations { @Override public List insertAll(Iterable instances) { - return doWithBatch(instances, entity -> createInsertChange(prepareVersionForInsert(entity)), this::performSaveAll); + + Assert.notNull(instances, "Aggregate instances must not be null"); + + if (!instances.iterator().hasNext()) { + return Collections.emptyList(); + } + + List> entityAndChangeCreators = new ArrayList<>(); + for (T instance : instances) { + + Function> changeCreator = entity -> createInsertChange(prepareVersionForInsert(entity)); + EntityAndChangeCreator entityChange = new EntityAndChangeCreator<>(instance, changeCreator); + entityAndChangeCreators.add(entityChange); + } + return performSaveAll(entityAndChangeCreators); } /** @@ -219,35 +241,21 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations { @Override public List updateAll(Iterable instances) { - return doWithBatch(instances, entity -> createUpdateChange(prepareVersionForUpdate(entity)), this::performSaveAll); - } - - private List doWithBatch(Iterable iterable, Function> changeCreator, - Function>, List> performFunction) { - return doWithBatch(iterable, changeCreator, entity -> {}, performFunction); - } - - private List doWithBatch(Iterable iterable, Function> changeCreator, - Consumer beforeEntityChange, Function>, List> performFunction) { - Assert.notNull(iterable, "Aggregate instances must not be null"); + Assert.notNull(instances, "Aggregate instances must not be null"); - if (ObjectUtils.isEmpty(iterable)) { + if (!instances.iterator().hasNext()) { return Collections.emptyList(); } - List> entityAndChangeCreators = new ArrayList<>( - iterable instanceof Collection c ? c.size() : 16); - - for (T instance : iterable) { - - beforeEntityChange.accept(instance); + List> entityAndChangeCreators = new ArrayList<>(); + for (T instance : instances) { + Function> changeCreator = entity -> createUpdateChange(prepareVersionForUpdate(entity)); EntityAndChangeCreator entityChange = new EntityAndChangeCreator<>(instance, changeCreator); entityAndChangeCreators.add(entityChange); } - - return performFunction.apply(entityAndChangeCreators); + return performSaveAll(entityAndChangeCreators); } @Override diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java index 726aa4bb0..7ec08f8a6 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java @@ -95,6 +95,18 @@ class JdbcRepositoryIdGenerationIntegrationTests { assertThat(immutableWithManualIdEntityRepository.findAll()).hasSize(1); } + @Test // DATAJDBC-393 + void manuallyGeneratedIdForSaveAll() { + + ImmutableWithManualIdEntity one = new ImmutableWithManualIdEntity(null, "one"); + ImmutableWithManualIdEntity two = new ImmutableWithManualIdEntity(null, "two"); + List saved = immutableWithManualIdEntityRepository.saveAll(List.of(one, two)); + + assertThat(saved).allSatisfy(e -> assertThat(e.id).isNotNull()); + + assertThat(immutableWithManualIdEntityRepository.findAll()).hasSize(2); + } + private interface PrimitiveIdEntityRepository extends ListCrudRepository {} private interface ReadOnlyIdEntityRepository extends ListCrudRepository {}