diff --git a/src/main/java/org/springframework/data/jdbc/repository/SimpleJdbcRepository.java b/src/main/java/org/springframework/data/jdbc/repository/SimpleJdbcRepository.java index e66502ac8..8ec48d2d5 100644 --- a/src/main/java/org/springframework/data/jdbc/repository/SimpleJdbcRepository.java +++ b/src/main/java/org/springframework/data/jdbc/repository/SimpleJdbcRepository.java @@ -23,6 +23,7 @@ import java.util.stream.StreamSupport; import javax.sql.DataSource; import org.springframework.data.jdbc.mapping.model.JdbcPersistentEntity; import org.springframework.data.jdbc.mapping.model.JdbcPersistentProperty; +import org.springframework.data.jdbc.repository.support.JdbcPersistentEntityInformation; import org.springframework.data.mapping.PropertyHandler; import org.springframework.data.repository.CrudRepository; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; @@ -37,6 +38,7 @@ import org.springframework.jdbc.support.KeyHolder; public class SimpleJdbcRepository implements CrudRepository { private final JdbcPersistentEntity entity; + private final JdbcPersistentEntityInformation entityInformation; private final NamedParameterJdbcOperations template; private final SqlGenerator sql; @@ -45,6 +47,7 @@ public class SimpleJdbcRepository implements CrudRep public SimpleJdbcRepository(JdbcPersistentEntity entity, DataSource dataSource) { this.entity = entity; + this.entityInformation = new JdbcPersistentEntityInformation(entity); this.template = new NamedParameterJdbcTemplate(dataSource); entityRowMapper = new EntityRowMapper(entity); @@ -54,14 +57,19 @@ public class SimpleJdbcRepository implements CrudRep @Override public S save(S instance) { - KeyHolder holder = new GeneratedKeyHolder(); + if (entityInformation.isNew(instance)) { - template.update( - sql.getInsert(), - new MapSqlParameterSource(getPropertyMap(instance)), - holder); + KeyHolder holder = new GeneratedKeyHolder(); + + template.update( + sql.getInsert(), + new MapSqlParameterSource(getPropertyMap(instance)), + holder); - entity.setId(instance, holder.getKey()); + entity.setId(instance, holder.getKey()); + } else { + template.update(sql.getUpdate(), getPropertyMap(instance)); + } return instance; } diff --git a/src/main/java/org/springframework/data/jdbc/repository/SqlGenerator.java b/src/main/java/org/springframework/data/jdbc/repository/SqlGenerator.java index b561029bb..2d613d78a 100644 --- a/src/main/java/org/springframework/data/jdbc/repository/SqlGenerator.java +++ b/src/main/java/org/springframework/data/jdbc/repository/SqlGenerator.java @@ -17,6 +17,7 @@ package org.springframework.data.jdbc.repository; import java.util.ArrayList; import java.util.List; +import java.util.stream.Collector; import java.util.stream.Collectors; import org.springframework.data.jdbc.mapping.model.JdbcPersistentEntity; import org.springframework.data.mapping.PropertyHandler; @@ -37,9 +38,13 @@ class SqlGenerator { private final String deleteByIdSql; private final String deleteAllSql; private final String deleteByListSql; + private final String updateSql; + private final List propertyNames = new ArrayList<>(); SqlGenerator(JdbcPersistentEntity entity) { + entity.doWithProperties((PropertyHandler) persistentProperty -> propertyNames.add(persistentProperty.getName())); + findOneSql = createFindOneSelectSql(entity); findAllSql = createFindAllSql(entity); findAllInListSql = createFindAllInListSql(entity); @@ -48,6 +53,7 @@ class SqlGenerator { countSql = createCountSql(entity); insertSql = createInsertSql(entity); + updateSql = createUpdateSql(entity); deleteByIdSql = createDeleteSql(entity); deleteAllSql = createDeleteAllSql(entity); @@ -74,6 +80,10 @@ class SqlGenerator { return insertSql; } + String getUpdate() { + return updateSql; + } + String getCount() { return countSql; } @@ -106,22 +116,26 @@ class SqlGenerator { } private String createCountSql(JdbcPersistentEntity entity) { - return String.format("select count(*) from %s", entity.getTableName(), entity.getIdColumn()); + return String.format("select count(*) from %s", entity.getTableName()); } private String createInsertSql(JdbcPersistentEntity entity) { - List propertyNames = new ArrayList<>(); - entity.doWithProperties((PropertyHandler) persistentProperty -> propertyNames.add(persistentProperty.getName())); - String insertTemplate = "insert into %s (%s) values (%s)"; - String tableName = entity.getType().getSimpleName(); - String tableColumns = propertyNames.stream().collect(Collectors.joining(", ")); String parameterNames = propertyNames.stream().collect(Collectors.joining(", :", ":", "")); - return String.format(insertTemplate, tableName, tableColumns, parameterNames); + return String.format(insertTemplate, entity.getTableName(), tableColumns, parameterNames); + } + + private String createUpdateSql(JdbcPersistentEntity entity) { + + String updateTemplate = "update %s set %s where %s = :%s"; + + String setClause = propertyNames.stream().map(n -> String.format("%s = :%s", n, n)).collect(Collectors.joining(", ")); + + return String.format(updateTemplate, entity.getTableName(), setClause, entity.getIdColumn(), entity.getIdColumn()); } private String createDeleteSql(JdbcPersistentEntity entity) { @@ -133,6 +147,7 @@ class SqlGenerator { } private String createDeleteByListSql(JdbcPersistentEntity entity) { - return String.format("delete from %s where id in (:ids)", entity.getTableName()); + return String.format("delete from %s where %s in (:ids)", entity.getTableName(), entity.getIdColumn()); } + } diff --git a/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java b/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java new file mode 100644 index 000000000..2f2b5396a --- /dev/null +++ b/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIdGenerationIntegrationTests.java @@ -0,0 +1,129 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jdbc.repository; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.*; + +import org.junit.After; +import org.junit.Test; +import org.springframework.data.annotation.Id; +import org.springframework.data.jdbc.repository.support.JdbcRepositoryFactory; +import org.springframework.data.repository.CrudRepository; +import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; + +import lombok.Data; + +/** + * testing special cases for Id generation with JdbcRepositories. + * + * @author Jens Schauder + */ +public class JdbcRepositoryIdGenerationIntegrationTests { + + private final EmbeddedDatabase db = new EmbeddedDatabaseBuilder() + .generateUniqueName(true) + .setType(EmbeddedDatabaseType.HSQL) + .setScriptEncoding("UTF-8") + .ignoreFailedDrops(true) + .addScript("org.springframework.data.jdbc.repository/jdbc-repository-id-generation-integration-tests.sql") + .build(); + + private final NamedParameterJdbcTemplate template = new NamedParameterJdbcTemplate(db); + + private final ReadOnlyIdEntityRepository repository = createRepository(db); + + private ReadOnlyIdEntity entity = createDummyEntity(); + + @After + public void after() { + db.shutdown(); + } + + @Test + public void idWithoutSetterGetsSet() { + + entity = repository.save(entity); + + assertThat(entity.getId()).isNotNull(); + + ReadOnlyIdEntity reloadedEntity = repository.findOne(entity.getId()); + + assertEquals( + entity.getId(), + reloadedEntity.getId()); + assertEquals( + entity.getName(), + reloadedEntity.getName()); + } + + @Test + public void primitiveIdGetsSet() { + + entity = repository.save(entity); + + assertThat(entity.getId()).isNotNull(); + + ReadOnlyIdEntity reloadedEntity = repository.findOne(entity.getId()); + + assertEquals( + entity.getId(), + reloadedEntity.getId()); + assertEquals( + entity.getName(), + reloadedEntity.getName()); + } + + + private static ReadOnlyIdEntityRepository createRepository(EmbeddedDatabase db) { + return new JdbcRepositoryFactory(db).getRepository(ReadOnlyIdEntityRepository.class); + } + + + private static ReadOnlyIdEntity createDummyEntity() { + + ReadOnlyIdEntity entity = new ReadOnlyIdEntity(null); + entity.setName("Entity Name"); + return entity; + } + + private interface ReadOnlyIdEntityRepository extends CrudRepository { + + } + + @Data + static class ReadOnlyIdEntity { + + @Id + private final Long id; + String name; + } + + private interface PrimitiveIdEntityRepository extends CrudRepository { + + } + + @Data + static class PrimitiveIdEntity { + + @Id + private final Long id; + String name; + } +} diff --git a/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java b/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java index 0251b976b..410b92fbd 100644 --- a/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java +++ b/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java @@ -44,14 +44,14 @@ public class JdbcRepositoryIntegrationTests { .setType(EmbeddedDatabaseType.HSQL) .setScriptEncoding("UTF-8") .ignoreFailedDrops(true) - .addScript("org.springframework.data.jdbc.repository/createTable.sql") + .addScript("org.springframework.data.jdbc.repository/jdbc-repository-integration-tests.sql") .build(); private final NamedParameterJdbcTemplate template = new NamedParameterJdbcTemplate(db); private final DummyEntityRepository repository = createRepository(db); - private DummyEntity entity = createDummyEntity(23L); + private DummyEntity entity = createDummyEntity(); @After public void after() { @@ -65,8 +65,8 @@ public class JdbcRepositoryIntegrationTests { entity = repository.save(entity); int count = template.queryForObject( - "SELECT count(*) FROM dummyentity WHERE id = :id", - new MapSqlParameterSource("id", entity.getId()), + "SELECT count(*) FROM dummyentity WHERE idProp = :id", + new MapSqlParameterSource("id", entity.getIdProp()), Integer.class); assertEquals( @@ -79,59 +79,24 @@ public class JdbcRepositoryIntegrationTests { entity = repository.save(entity); - DummyEntity reloadedEntity = repository.findOne(entity.getId()); + DummyEntity reloadedEntity = repository.findOne(entity.getIdProp()); assertEquals( - entity.getId(), - reloadedEntity.getId()); + entity.getIdProp(), + reloadedEntity.getIdProp()); assertEquals( entity.getName(), reloadedEntity.getName()); } - @Test - public void canSaveAndLoadAnEntityWithDatabaseBasedIdGeneration() { - - entity = createDummyEntity(null); - - entity = repository.save(entity); - - assertThat(entity).isNotNull(); - - DummyEntity reloadedEntity = repository.findOne(entity.getId()); - - assertEquals( - entity.getId(), - reloadedEntity.getId()); - assertEquals( - entity.getName(), - reloadedEntity.getName()); - } - - @Test public void saveMany() { - DummyEntity other = createDummyEntity(24L); + DummyEntity other = createDummyEntity(); repository.save(asList(entity, other)); - assertThat(repository.findAll()).extracting(DummyEntity::getId).containsExactlyInAnyOrder(23L, 24L); - } - - @Test - public void saveManyWithIdGeneration() { - - DummyEntity one = createDummyEntity(null); - DummyEntity two = createDummyEntity(null); - - Iterable entities = repository.save(asList(one, two)); - - assertThat(entities).allMatch(e -> e.getId() != null); - - assertThat(repository.findAll()) - .extracting(DummyEntity::getId) - .containsExactlyInAnyOrder(new Long[]{one.getId(), two.getId()}); + assertThat(repository.findAll()).extracting(DummyEntity::getIdProp).containsExactlyInAnyOrder(entity.getIdProp(), other.getIdProp()); } @Test @@ -139,40 +104,40 @@ public class JdbcRepositoryIntegrationTests { entity = repository.save(entity); - assertTrue(repository.exists(entity.getId())); - assertFalse(repository.exists(entity.getId() + 1)); + assertTrue(repository.exists(entity.getIdProp())); + assertFalse(repository.exists(entity.getIdProp() + 1)); } @Test public void findAllFindsAllEntities() { - DummyEntity other = createDummyEntity(24L); + DummyEntity other = createDummyEntity(); other = repository.save(other); entity = repository.save(entity); Iterable all = repository.findAll(); - assertThat(all).extracting("id").containsExactlyInAnyOrder(entity.getId(), other.getId()); + assertThat(all).extracting("idProp").containsExactlyInAnyOrder(entity.getIdProp(), other.getIdProp()); } @Test public void findAllFindsAllSpecifiedEntities() { - repository.save(createDummyEntity(24L)); - DummyEntity other = repository.save(createDummyEntity(25L)); + DummyEntity two = repository.save(createDummyEntity()); + DummyEntity three = repository.save(createDummyEntity()); entity = repository.save(entity); - Iterable all = repository.findAll(asList(entity.getId(), other.getId())); + Iterable all = repository.findAll(asList(entity.getIdProp(), three.getIdProp())); - assertThat(all).extracting("id").containsExactlyInAnyOrder(entity.getId(), other.getId()); + assertThat(all).extracting("idProp").containsExactlyInAnyOrder(entity.getIdProp(), three.getIdProp()); } @Test public void count() { - repository.save(createDummyEntity(24L)); - repository.save(createDummyEntity(25L)); + repository.save(createDummyEntity()); + repository.save(createDummyEntity()); repository.save(entity); assertThat(repository.count()).isEqualTo(3L); @@ -181,25 +146,27 @@ public class JdbcRepositoryIntegrationTests { @Test public void deleteById() { - repository.save(createDummyEntity(24L)); - repository.save(createDummyEntity(25L)); - repository.save(entity); + entity = repository.save(entity); + DummyEntity two = repository.save(createDummyEntity()); + DummyEntity three = repository.save(createDummyEntity()); - repository.delete(24L); + repository.delete(two.getIdProp()); - assertThat(repository.findAll()).extracting(DummyEntity::getId).containsExactlyInAnyOrder(23L, 25L); + assertThat(repository.findAll()) + .extracting(DummyEntity::getIdProp) + .containsExactlyInAnyOrder(entity.getIdProp(), three.getIdProp()); } @Test public void deleteByEntity() { - repository.save(createDummyEntity(24L)); - repository.save(createDummyEntity(25L)); - repository.save(entity); + entity = repository.save(entity); + DummyEntity two = repository.save(createDummyEntity()); + DummyEntity three = repository.save(createDummyEntity()); repository.delete(entity); - assertThat(repository.findAll()).extracting(DummyEntity::getId).containsExactlyInAnyOrder(24L, 25L); + assertThat(repository.findAll()).extracting(DummyEntity::getIdProp).containsExactlyInAnyOrder(two.getIdProp(), three.getIdProp()); } @@ -207,20 +174,20 @@ public class JdbcRepositoryIntegrationTests { public void deleteByList() { repository.save(entity); - repository.save(createDummyEntity(24L)); - DummyEntity other = repository.save(createDummyEntity(25L)); + DummyEntity two = repository.save(createDummyEntity()); + DummyEntity three = repository.save(createDummyEntity()); - repository.delete(asList(entity, other)); + repository.delete(asList(entity, three)); - assertThat(repository.findAll()).extracting(DummyEntity::getId).containsExactlyInAnyOrder(24L); + assertThat(repository.findAll()).extracting(DummyEntity::getIdProp).containsExactlyInAnyOrder(two.getIdProp()); } @Test public void deleteAll() { repository.save(entity); - repository.save(createDummyEntity(24L)); - repository.save(createDummyEntity(25L)); + repository.save(createDummyEntity()); + repository.save(createDummyEntity()); repository.deleteAll(); @@ -228,15 +195,44 @@ public class JdbcRepositoryIntegrationTests { } + @Test + public void update() { + + entity = repository.save(entity); + + entity.setName("something else"); + + entity = repository.save(entity); + + DummyEntity reloaded = repository.findOne(entity.getIdProp()); + + assertThat(reloaded.getName()).isEqualTo(entity.getName()); + } + + @Test + public void updateMany() { + + entity = repository.save(entity); + DummyEntity other = repository.save(createDummyEntity()); + + entity.setName("something else"); + other.setName("others Name"); + + repository.save(asList(entity, other)); + + assertThat(repository.findAll()) + .extracting(DummyEntity::getName) + .containsExactlyInAnyOrder(entity.getName(), other.getName()); + } + private static DummyEntityRepository createRepository(EmbeddedDatabase db) { return new JdbcRepositoryFactory(db).getRepository(DummyEntityRepository.class); } - private static DummyEntity createDummyEntity(Long id) { + private static DummyEntity createDummyEntity() { DummyEntity entity = new DummyEntity(); - entity.setId(id); entity.setName("Entity Name"); return entity; } @@ -245,12 +241,11 @@ public class JdbcRepositoryIntegrationTests { } - // needs to be public in order for the Hamcrest property matcher to work. @Data - public static class DummyEntity { + static class DummyEntity { @Id - Long id; + private Long idProp; String name; } } diff --git a/src/test/resources/org.springframework.data.jdbc.repository/createTable.sql b/src/test/resources/org.springframework.data.jdbc.repository/createTable.sql deleted file mode 100644 index c818bad88..000000000 --- a/src/test/resources/org.springframework.data.jdbc.repository/createTable.sql +++ /dev/null @@ -1 +0,0 @@ -CREATE TABLE dummyentity (ID BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 4711) PRIMARY KEY, NAME VARCHAR(100)) \ No newline at end of file diff --git a/src/test/resources/org.springframework.data.jdbc.repository/jdbc-repository-id-generation-integration-tests.sql b/src/test/resources/org.springframework.data.jdbc.repository/jdbc-repository-id-generation-integration-tests.sql new file mode 100644 index 000000000..0e008210e --- /dev/null +++ b/src/test/resources/org.springframework.data.jdbc.repository/jdbc-repository-id-generation-integration-tests.sql @@ -0,0 +1,3 @@ +-- noinspection SqlNoDataSourceInspectionForFile + +CREATE TABLE ReadOnlyIdEntity (ID BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 1) PRIMARY KEY, NAME VARCHAR(100)) \ No newline at end of file diff --git a/src/test/resources/org.springframework.data.jdbc.repository/jdbc-repository-integration-tests.sql b/src/test/resources/org.springframework.data.jdbc.repository/jdbc-repository-integration-tests.sql new file mode 100644 index 000000000..04c850b9f --- /dev/null +++ b/src/test/resources/org.springframework.data.jdbc.repository/jdbc-repository-integration-tests.sql @@ -0,0 +1 @@ +CREATE TABLE dummyentity (idProp BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 1) PRIMARY KEY, NAME VARCHAR(100)) \ No newline at end of file