diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java index c15365fec86..753fac912fb 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java @@ -242,7 +242,7 @@ final class DefaultJdbcClient implements JdbcClient { public int update(KeyHolder generatedKeyHolder) { return (useNamedParams() ? namedParamOps.update(this.sql, this.namedParamSource, generatedKeyHolder) : - classicOps.update(getPreparedStatementCreatorForIndexedParams(), generatedKeyHolder)); + classicOps.update(getPreparedStatementCreatorForIndexedParams(true), generatedKeyHolder)); } private boolean useNamedParams() { @@ -258,7 +258,13 @@ final class DefaultJdbcClient implements JdbcClient { } private PreparedStatementCreator getPreparedStatementCreatorForIndexedParams() { - return new PreparedStatementCreatorFactory(this.sql).newPreparedStatementCreator(this.indexedParams); + return getPreparedStatementCreatorForIndexedParams(false); + } + + private PreparedStatementCreator getPreparedStatementCreatorForIndexedParams(boolean returnGeneratedKeys) { + PreparedStatementCreatorFactory factory = new PreparedStatementCreatorFactory(this.sql); + factory.setReturnGeneratedKeys(returnGeneratedKeys); + return factory.newPreparedStatementCreator(this.indexedParams); } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIntegrationTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIntegrationTests.java new file mode 100644 index 00000000000..1ebada3139f --- /dev/null +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIntegrationTests.java @@ -0,0 +1,118 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jdbc.core.simple; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.io.ClassRelativeResourceLoader; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.init.DatabasePopulator; +import org.springframework.jdbc.support.GeneratedKeyHolder; +import org.springframework.jdbc.support.KeyHolder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType.H2; + +/** + * Integration tests for {@link JdbcClient} using an embedded H2 database. + * + * @author Sam Brannen + * @since 6.1 + * @see JdbcClientIndexedParameterTests + * @see JdbcClientNamedParameterTests + */ +class JdbcClientIntegrationTests { + + private static final String INSERT_WITH_NAMED_PARAMS = + "INSERT INTO users (first_name, last_name) VALUES(:firstName, :lastName)"; + private static final String INSERT_WITH_POSITIONAL_PARAMS = + "INSERT INTO users (first_name, last_name) VALUES(?, ?)"; + + + private final EmbeddedDatabase embeddedDatabase = + new EmbeddedDatabaseBuilder(new ClassRelativeResourceLoader(DatabasePopulator.class)) + .generateUniqueName(true) + .setType(H2) + .addScripts("users-schema.sql", "users-data.sql") + .build(); + + private final JdbcClient jdbcClient = JdbcClient.create(this.embeddedDatabase); + + + @BeforeEach + void checkDatabase() { + assertNumUsers(1); + } + + @AfterEach + void shutdownDatabase() { + this.embeddedDatabase.shutdown(); + } + + @Test + void updateWithGeneratedKeysAndPositionalParameters() { + int expectedId = 2; + String firstName = "Jane"; + String lastName = "Smith"; + + KeyHolder generatedKeyHolder = new GeneratedKeyHolder(); + + int rowsAffected = this.jdbcClient.sql(INSERT_WITH_POSITIONAL_PARAMS) + .params(firstName, lastName) + .update(generatedKeyHolder); + + assertThat(rowsAffected).isEqualTo(1); + assertThat(generatedKeyHolder.getKey()).isEqualTo(expectedId); + assertNumUsers(2); + assertUser(expectedId, firstName, lastName); + } + + @Test + void updateWithGeneratedKeysAndNamedParameters() { + int expectedId = 2; + String firstName = "Jane"; + String lastName = "Smith"; + + KeyHolder generatedKeyHolder = new GeneratedKeyHolder(); + + int rowsAffected = this.jdbcClient.sql(INSERT_WITH_NAMED_PARAMS) + .param("firstName", firstName) + .param("lastName", lastName) + .update(generatedKeyHolder); + + assertThat(rowsAffected).isEqualTo(1); + assertThat(generatedKeyHolder.getKey()).isEqualTo(expectedId); + assertNumUsers(2); + assertUser(expectedId, firstName, lastName); + } + + private void assertNumUsers(long count) { + long numUsers = this.jdbcClient.sql("select count(id) from users").query(Long.class).single(); + assertThat(numUsers).isEqualTo(count); + } + + private void assertUser(long id, String firstName, String lastName) { + User user = this.jdbcClient.sql("select * from users where id = ?").param(id).query(User.class).single(); + assertThat(user).isEqualTo(new User(id, firstName, lastName)); + } + + record User(long id, String firstName, String lastName) {}; + +}