diff --git a/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java b/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java index ddc09f557..09250b70f 100644 --- a/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java +++ b/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java @@ -379,33 +379,10 @@ public class MappingR2dbcConverter extends BasicRelationalConverter implements R Object result = getPotentiallyConvertedSimpleWrite(value); - if (property.isIdProperty() && isNew) { - if (shouldSkipIdValue(result, property)) { - return; - } - } - sink.put(property.getColumnName(), Parameter.fromOrEmpty(result, getPotentiallyConvertedSimpleNullType(property.getType()))); } - private boolean shouldSkipIdValue(@Nullable Object value, RelationalPersistentProperty property) { - - if (value == null) { - return true; - } - - if (!property.getType().isPrimitive()) { - return value == null; - } - - if (value instanceof Number) { - return ((Number) value).longValue() == 0L; - } - - return false; - } - private void writePropertyInternal(OutboundRow sink, Object value, boolean isNew, RelationalPersistentProperty property) { diff --git a/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java b/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java index 8aede1d6c..2ee6a5555 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java +++ b/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java @@ -546,11 +546,41 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw OutboundRow outboundRow = dataAccessStrategy.getOutboundRow(initializedEntity); + potentiallyRemoveId(persistentEntity, outboundRow); + return maybeCallBeforeSave(initializedEntity, outboundRow, tableName) // .flatMap(entityToSave -> doInsert(entityToSave, tableName, outboundRow)); }); } + private void potentiallyRemoveId(RelationalPersistentEntity persistentEntity, OutboundRow outboundRow) { + + RelationalPersistentProperty idProperty = persistentEntity.getIdProperty(); + if (idProperty == null) { + return; + } + + SqlIdentifier columnName = idProperty.getColumnName(); + Parameter parameter = outboundRow.get(columnName); + + if (shouldSkipIdValue(parameter, idProperty)) { + outboundRow.remove(columnName); + } + } + + private boolean shouldSkipIdValue(@Nullable Parameter value, RelationalPersistentProperty property) { + + if (value == null || value.getValue() == null) { + return true; + } + + if (value.getValue() instanceof Number) { + return ((Number) value.getValue()).longValue() == 0L; + } + + return false; + } + private Mono doInsert(T entity, SqlIdentifier tableName, OutboundRow outboundRow) { StatementMapper mapper = dataAccessStrategy.getStatementMapper(); diff --git a/src/test/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverterUnitTests.java b/src/test/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverterUnitTests.java index 2f6b61b2d..51b4b21ca 100644 --- a/src/test/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverterUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverterUnitTests.java @@ -212,15 +212,6 @@ public class MappingR2dbcConverterUnitTests { assertThat(result.entity).isNotNull(); } - @Test // gh-402 - public void writeShouldSkipPrimitiveIdIfValueIsZero() { - - OutboundRow row = new OutboundRow(); - converter.write(new WithPrimitiveId(0), row); - - assertThat(row).isEmpty(); - } - @Test // gh-402 public void writeShouldWritePrimitiveIdIfValueIsNonZero() { diff --git a/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java index 54f3497c0..013dfb251 100644 --- a/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java @@ -280,6 +280,46 @@ public class R2dbcEntityTemplateUnitTests { Parameter.from(1L)); } + @Test // gh-557, gh-402 + public void shouldSkipDefaultIdValueOnInsert() { + + MockRowMetadata metadata = MockRowMetadata.builder().build(); + MockResult result = MockResult.builder().rowMetadata(metadata).rowsUpdated(1).build(); + + recorder.addStubbing(s -> s.startsWith("INSERT"), result); + + entityTemplate.insert(new PersonWithPrimitiveId(0, "bar")).as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("INSERT")); + + assertThat(statement.getSql()).isEqualTo("INSERT INTO person_with_primitive_id (name) VALUES ($1)"); + assertThat(statement.getBindings()).hasSize(1).containsEntry(0, Parameter.from("bar")); + } + + @Test // gh-557, gh-402 + public void shouldSkipDefaultIdValueOnVersionedInsert() { + + MockRowMetadata metadata = MockRowMetadata.builder().build(); + MockResult result = MockResult.builder().rowMetadata(metadata).rowsUpdated(1).build(); + + recorder.addStubbing(s -> s.startsWith("INSERT"), result); + + entityTemplate.insert(new VersionedPersonWithPrimitiveId(0, 0, "bar")).as(StepVerifier::create) // + .assertNext(actual -> { + assertThat(actual.getVersion()).isEqualTo(1); + }) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("INSERT")); + + assertThat(statement.getSql()) + .isEqualTo("INSERT INTO versioned_person_with_primitive_id (version, name) VALUES ($1, $2)"); + assertThat(statement.getBindings()).hasSize(2).containsEntry(0, Parameter.from(1L)).containsEntry(1, + Parameter.from("bar")); + } + @Test // gh-451 public void shouldInsertCorrectlyVersionedAndAudited() { @@ -449,6 +489,26 @@ public class R2dbcEntityTemplateUnitTests { String name; } + @Value + @With + static class PersonWithPrimitiveId { + + @Id int id; + + String name; + } + + @Value + @With + static class VersionedPersonWithPrimitiveId { + + @Id int id; + + @Version long version; + + String name; + } + @Value @With static class WithAuditingAndOptimisticLocking {