diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java index b6c90a161f8..7ef7d6ccfc0 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java @@ -218,7 +218,7 @@ final class DefaultJdbcClient implements JdbcClient { namedParamOps.query(this.sql, this.namedParamSource, rch); } else { - classicOps.query(getPreparedStatementCreatorForIndexedParams(), rch); + classicOps.query(statementCreatorForIndexedParams(), rch); } } @@ -226,7 +226,7 @@ final class DefaultJdbcClient implements JdbcClient { public T query(ResultSetExtractor rse) { T result = (useNamedParams() ? namedParamOps.query(this.sql, this.namedParamSource, rse) : - classicOps.query(getPreparedStatementCreatorForIndexedParams(), rse)); + classicOps.query(statementCreatorForIndexedParams(), rse)); Assert.state(result != null, "No result from ResultSetExtractor"); return result; } @@ -235,14 +235,21 @@ final class DefaultJdbcClient implements JdbcClient { public int update() { return (useNamedParams() ? namedParamOps.update(this.sql, this.namedParamSource) : - classicOps.update(getPreparedStatementCreatorForIndexedParams())); + classicOps.update(statementCreatorForIndexedParams())); } @Override public int update(KeyHolder generatedKeyHolder) { return (useNamedParams() ? namedParamOps.update(this.sql, this.namedParamSource, generatedKeyHolder) : - classicOps.update(getPreparedStatementCreatorForIndexedParams(true), generatedKeyHolder)); + classicOps.update(statementCreatorForIndexedParamsWithKeys(null), generatedKeyHolder)); + } + + @Override + public int update(KeyHolder generatedKeyHolder, String... keyColumnNames) { + return (useNamedParams() ? + namedParamOps.update(this.sql, this.namedParamSource, generatedKeyHolder, keyColumnNames) : + classicOps.update(statementCreatorForIndexedParamsWithKeys(keyColumnNames), generatedKeyHolder)); } private boolean useNamedParams() { @@ -257,14 +264,19 @@ final class DefaultJdbcClient implements JdbcClient { return hasNamedParams; } - private PreparedStatementCreator getPreparedStatementCreatorForIndexedParams() { - return getPreparedStatementCreatorForIndexedParams(false); + private PreparedStatementCreator statementCreatorForIndexedParams() { + return new PreparedStatementCreatorFactory(this.sql).newPreparedStatementCreator(this.indexedParams); } - private PreparedStatementCreator getPreparedStatementCreatorForIndexedParams(boolean returnGeneratedKeys) { - PreparedStatementCreatorFactory factory = new PreparedStatementCreatorFactory(this.sql); - factory.setReturnGeneratedKeys(returnGeneratedKeys); - return factory.newPreparedStatementCreator(this.indexedParams); + private PreparedStatementCreator statementCreatorForIndexedParamsWithKeys(@Nullable String[] keyColumnNames) { + PreparedStatementCreatorFactory pscf = new PreparedStatementCreatorFactory(this.sql); + if (keyColumnNames != null) { + pscf.setGeneratedKeysColumnNames(keyColumnNames); + } + else { + pscf.setReturnGeneratedKeys(true); + } + return pscf.newPreparedStatementCreator(this.indexedParams); } diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/JdbcClient.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/JdbcClient.java index e42e67ab114..5e6d98e6ddc 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/JdbcClient.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/JdbcClient.java @@ -288,6 +288,16 @@ public interface JdbcClient { * @see java.sql.PreparedStatement#executeUpdate() */ int update(KeyHolder generatedKeyHolder); + + /** + * Execute the provided SQL statement as an update. + * @param generatedKeyHolder a KeyHolder that will hold the generated keys + * (typically a {@link org.springframework.jdbc.support.GeneratedKeyHolder}) + * @param keyColumnNames names of the columns that will have keys generated for them + * @return the number of rows affected + * @see java.sql.PreparedStatement#executeUpdate() + */ + int update(KeyHolder generatedKeyHolder, String... keyColumnNames); } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIndexedParameterTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIndexedParameterTests.java index 9789e4bb462..a0158a36f8a 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIndexedParameterTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIndexedParameterTests.java @@ -339,7 +339,7 @@ public class JdbcClientIndexedParameterTests { } @Test - public void updateAndGeneratedKeys() throws SQLException { + public void updateWithGeneratedKeys() throws SQLException { given(resultSetMetaData.getColumnCount()).willReturn(1); given(resultSetMetaData.getColumnLabel(1)).willReturn("1"); given(resultSet.getMetaData()).willReturn(resultSetMetaData); @@ -362,4 +362,28 @@ public class JdbcClientIndexedParameterTests { verify(connection).close(); } + @Test + public void updateWithGeneratedKeysAndKeyColumnNames() throws SQLException { + given(resultSetMetaData.getColumnCount()).willReturn(1); + given(resultSetMetaData.getColumnLabel(1)).willReturn("1"); + given(resultSet.getMetaData()).willReturn(resultSetMetaData); + given(resultSet.next()).willReturn(true, false); + given(resultSet.getObject(1)).willReturn(11); + given(preparedStatement.executeUpdate()).willReturn(1); + given(preparedStatement.getGeneratedKeys()).willReturn(resultSet); + given(connection.prepareStatement(INSERT_GENERATE_KEYS, new String[] {"id"})) + .willReturn(preparedStatement); + + KeyHolder generatedKeyHolder = new GeneratedKeyHolder(); + int rowsAffected = client.sql(INSERT_GENERATE_KEYS).param("rod").update(generatedKeyHolder, "id"); + + assertThat(rowsAffected).isEqualTo(1); + assertThat(generatedKeyHolder.getKeyList()).hasSize(1); + assertThat(generatedKeyHolder.getKey()).isEqualTo(11); + verify(preparedStatement).setString(1, "rod"); + verify(resultSet).close(); + verify(preparedStatement).close(); + verify(connection).close(); + } + } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIntegrationTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIntegrationTests.java index 1ebada3139f..334230560ec 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIntegrationTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIntegrationTests.java @@ -34,16 +34,18 @@ import static org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType. * Integration tests for {@link JdbcClient} using an embedded H2 database. * * @author Sam Brannen + * @author Juergen Hoeller * @since 6.1 * @see JdbcClientIndexedParameterTests * @see JdbcClientNamedParameterTests */ class JdbcClientIntegrationTests { + private static final String INSERT_WITH_JDBC_PARAMS = + "INSERT INTO users (first_name, last_name) VALUES(?, ?)"; + private static final String INSERT_WITH_NAMED_PARAMS = "INSERT INTO users (first_name, last_name) VALUES(:firstName, :lastName)"; - private static final String INSERT_WITH_POSITIONAL_PARAMS = - "INSERT INTO users (first_name, last_name) VALUES(?, ?)"; private final EmbeddedDatabase embeddedDatabase = @@ -66,15 +68,16 @@ class JdbcClientIntegrationTests { this.embeddedDatabase.shutdown(); } + @Test - void updateWithGeneratedKeysAndPositionalParameters() { + void updateWithGeneratedKeys() { int expectedId = 2; String firstName = "Jane"; String lastName = "Smith"; KeyHolder generatedKeyHolder = new GeneratedKeyHolder(); - int rowsAffected = this.jdbcClient.sql(INSERT_WITH_POSITIONAL_PARAMS) + int rowsAffected = this.jdbcClient.sql(INSERT_WITH_JDBC_PARAMS) .params(firstName, lastName) .update(generatedKeyHolder); @@ -85,7 +88,25 @@ class JdbcClientIntegrationTests { } @Test - void updateWithGeneratedKeysAndNamedParameters() { + void updateWithGeneratedKeysAndKeyColumnNames() { + int expectedId = 2; + String firstName = "Jane"; + String lastName = "Smith"; + + KeyHolder generatedKeyHolder = new GeneratedKeyHolder(); + + int rowsAffected = this.jdbcClient.sql(INSERT_WITH_JDBC_PARAMS) + .params(firstName, lastName) + .update(generatedKeyHolder, "id"); + + assertThat(rowsAffected).isEqualTo(1); + assertThat(generatedKeyHolder.getKey()).isEqualTo(expectedId); + assertNumUsers(2); + assertUser(expectedId, firstName, lastName); + } + + @Test + void updateWithGeneratedKeysUsingNamedParameters() { int expectedId = 2; String firstName = "Jane"; String lastName = "Smith"; @@ -103,6 +124,26 @@ class JdbcClientIntegrationTests { assertUser(expectedId, firstName, lastName); } + @Test + void updateWithGeneratedKeysAndKeyColumnNamesUsingNamedParameters() { + int expectedId = 2; + String firstName = "Jane"; + String lastName = "Smith"; + + KeyHolder generatedKeyHolder = new GeneratedKeyHolder(); + + int rowsAffected = this.jdbcClient.sql(INSERT_WITH_NAMED_PARAMS) + .param("firstName", firstName) + .param("lastName", lastName) + .update(generatedKeyHolder, "id"); + + assertThat(rowsAffected).isEqualTo(1); + assertThat(generatedKeyHolder.getKey()).isEqualTo(expectedId); + assertNumUsers(2); + assertUser(expectedId, firstName, lastName); + } + + private void assertNumUsers(long count) { long numUsers = this.jdbcClient.sql("select count(id) from users").query(Long.class).single(); assertThat(numUsers).isEqualTo(count); @@ -113,6 +154,7 @@ class JdbcClientIntegrationTests { assertThat(user).isEqualTo(new User(id, firstName, lastName)); } + record User(long id, String firstName, String lastName) {}; } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientNamedParameterTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientNamedParameterTests.java index 56487cbe7a1..784b6917e3a 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientNamedParameterTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientNamedParameterTests.java @@ -406,7 +406,7 @@ class JdbcClientNamedParameterTests { } @Test - void updateAndGeneratedKeys() throws SQLException { + void updateWithGeneratedKeys() throws SQLException { given(resultSetMetaData.getColumnCount()).willReturn(1); given(resultSetMetaData.getColumnLabel(1)).willReturn("1"); given(resultSet.getMetaData()).willReturn(resultSetMetaData); @@ -429,4 +429,28 @@ class JdbcClientNamedParameterTests { verify(connection).close(); } + @Test + public void updateWithGeneratedKeysAndKeyColumnNames() throws SQLException { + given(resultSetMetaData.getColumnCount()).willReturn(1); + given(resultSetMetaData.getColumnLabel(1)).willReturn("1"); + given(resultSet.getMetaData()).willReturn(resultSetMetaData); + given(resultSet.next()).willReturn(true, false); + given(resultSet.getObject(1)).willReturn(11); + given(preparedStatement.executeUpdate()).willReturn(1); + given(preparedStatement.getGeneratedKeys()).willReturn(resultSet); + given(connection.prepareStatement(INSERT_GENERATE_KEYS_PARSED, new String[] {"id"})) + .willReturn(preparedStatement); + + KeyHolder generatedKeyHolder = new GeneratedKeyHolder(); + int rowsAffected = client.sql(INSERT_GENERATE_KEYS).param("name", "rod").update(generatedKeyHolder, "id"); + + assertThat(rowsAffected).isEqualTo(1); + assertThat(generatedKeyHolder.getKeyList()).hasSize(1); + assertThat(generatedKeyHolder.getKey()).isEqualTo(11); + verify(preparedStatement).setString(1, "rod"); + verify(resultSet).close(); + verify(preparedStatement).close(); + verify(connection).close(); + } + }