diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/AggregateChangeExecutor.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/AggregateChangeExecutor.java index 5f7546ff9..70e68d8c8 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/AggregateChangeExecutor.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/AggregateChangeExecutor.java @@ -68,18 +68,21 @@ class AggregateChangeExecutor { List> actions = new ArrayList<>(); aggregateChange.forEachAction(action -> { + action.executeWith(interpreter); actions.add(action); }); - T newRoot = populateIdsIfNecessary(actions); - if (newRoot != null) { - newRoot = populateRootVersionIfNecessary(newRoot, actions); - aggregateChange.setEntity(newRoot); + T root = populateIdsIfNecessary(actions); + root = root == null ? aggregateChange.getEntity() : root; + + if (root != null) { + + root = populateRootVersionIfNecessary(root, actions); + aggregateChange.setEntity(root); } } - @SuppressWarnings("unchecked") private T populateRootVersionIfNecessary(T newRoot, List> actions) { // Does the root entity have a version attribute? @@ -90,7 +93,8 @@ class AggregateChangeExecutor { } // Find the root action - Optional> rootAction = actions.parallelStream().filter(action -> action instanceof DbAction.WithVersion) + Optional> rootAction = actions.parallelStream() // + .filter(action -> action instanceof DbAction.WithVersion) // .findFirst(); if (!rootAction.isPresent()) { diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/DefaultJdbcInterpreter.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/DefaultJdbcInterpreter.java index 0bc2a6e92..a075197ff 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/DefaultJdbcInterpreter.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/DefaultJdbcInterpreter.java @@ -118,30 +118,37 @@ class DefaultJdbcInterpreter implements Interpreter { RelationalPersistentEntity persistentEntity = getRequiredPersistentEntity(update.getEntityType()); if (persistentEntity.hasVersionProperty()) { + updateWithVersion(update, persistentEntity); + } else { + updateWithoutVersion(update); + } + } + + private void updateWithoutVersion(UpdateRoot update) { + + if (!accessStrategy.update(update.getEntity(), update.getEntityType())) { + + throw new IncorrectUpdateSemanticsDataAccessException( + String.format(UPDATE_FAILED, update.getEntity(), getIdFrom(update))); + } + } - // If the root aggregate has a version property, increment it. - Number previousVersion = RelationalEntityVersionUtils.getVersionNumberFromEntity(update.getEntity(), - persistentEntity, converter); + private void updateWithVersion(UpdateRoot update, RelationalPersistentEntity persistentEntity) { - Assert.notNull(previousVersion, "The root aggregate cannot be updated because the version property is null."); + // If the root aggregate has a version property, increment it. + Number previousVersion = RelationalEntityVersionUtils.getVersionNumberFromEntity(update.getEntity(), + persistentEntity, converter); - T rootEntity = RelationalEntityVersionUtils.setVersionNumberOnEntity(update.getEntity(), - previousVersion.longValue() + 1, persistentEntity, converter); + Assert.notNull(previousVersion, "The root aggregate cannot be updated because the version property is null."); - if (accessStrategy.updateWithVersion(rootEntity, update.getEntityType(), previousVersion)) { - // Successful update, set the in-memory version on the action. - update.setNextVersion(previousVersion); - } else { - throw new OptimisticLockingFailureException( - String.format(UPDATE_FAILED_OPTIMISTIC_LOCKING, update.getEntity())); - } - } else { + update.setNextVersion(previousVersion.longValue() + 1); + T rootEntity = RelationalEntityVersionUtils.setVersionNumberOnEntity(update.getEntity(), update.getNextVersion(), + persistentEntity, converter); - if (!accessStrategy.update(update.getEntity(), update.getEntityType())) { + if (!accessStrategy.updateWithVersion(rootEntity, update.getEntityType(), previousVersion)) { - throw new IncorrectUpdateSemanticsDataAccessException( - String.format(UPDATE_FAILED, update.getEntity(), getIdFrom(update))); - } + throw new OptimisticLockingFailureException( + String.format(UPDATE_FAILED_OPTIMISTIC_LOCKING, update.getEntity())); } } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/JdbcAggregateTemplateIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/JdbcAggregateTemplateIntegrationTests.java index 8a2a69ac3..9a24375eb 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/JdbcAggregateTemplateIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/JdbcAggregateTemplateIntegrationTests.java @@ -709,13 +709,15 @@ public class JdbcAggregateTemplateIntegrationTests { assertThat(reloadedAggregate.getVersion()).describedAs("version field should initially have the value 1") .isEqualTo(1L); - AggregateWithImmutableVersion saved = template.save(reloadedAggregate); - AggregateWithImmutableVersion updatedAggregate = template.findById(id, aggregate.getClass()); + AggregateWithImmutableVersion savedAgain = template.save(reloadedAggregate); + AggregateWithImmutableVersion reloadedAgain = template.findById(id, aggregate.getClass()); - assertThat(saved.version) - .describedAs("returned by save(): "+ saved + " vs. returned by findById(): " + updatedAggregate) - .isEqualTo(updatedAggregate.version); - assertThat(updatedAggregate.getVersion()).describedAs("version field should increment by one with each save") + assertThat(savedAgain.version) + .describedAs("The object returned by save should have an increased version") + .isEqualTo(2L); + + assertThat(reloadedAgain.getVersion()) + .describedAs("version field should increment by one with each save") .isEqualTo(2L); assertThatThrownBy(() -> template.save(new AggregateWithImmutableVersion(id, 1L)))