Browse Source

Use PreparedStatementCreator for query/update with indexed params

Closes gh-31122
pull/31158/head
Juergen Hoeller 2 years ago
parent
commit
855fe39b7f
  1. 24
      spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java
  2. 14
      spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java
  3. 32
      spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIndexedParameterTests.java
  4. 34
      spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientNamedParameterTests.java
  5. 6
      spring-jdbc/src/test/java/org/springframework/jdbc/object/SqlUpdateTests.java

24
spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -46,8 +46,9 @@ public class PreparedStatementCreatorFactory { @@ -46,8 +46,9 @@ public class PreparedStatementCreatorFactory {
/** The SQL, which won't change when the parameters change. */
private final String sql;
/** List of SqlParameter objects (may not be {@code null}). */
private final List<SqlParameter> declaredParameters;
/** List of SqlParameter objects (may be {@code null}). */
@Nullable
private List<SqlParameter> declaredParameters;
private int resultSetType = ResultSet.TYPE_FORWARD_ONLY;
@ -66,7 +67,6 @@ public class PreparedStatementCreatorFactory { @@ -66,7 +67,6 @@ public class PreparedStatementCreatorFactory {
*/
public PreparedStatementCreatorFactory(String sql) {
this.sql = sql;
this.declaredParameters = new ArrayList<>();
}
/**
@ -104,6 +104,9 @@ public class PreparedStatementCreatorFactory { @@ -104,6 +104,9 @@ public class PreparedStatementCreatorFactory {
* @param param the parameter to add to the list of declared parameters
*/
public void addParameter(SqlParameter param) {
if (this.declaredParameters == null) {
this.declaredParameters = new ArrayList<>();
}
this.declaredParameters.add(param);
}
@ -180,7 +183,7 @@ public class PreparedStatementCreatorFactory { @@ -180,7 +183,7 @@ public class PreparedStatementCreatorFactory {
*/
public PreparedStatementCreator newPreparedStatementCreator(String sqlToUse, @Nullable Object[] params) {
return new PreparedStatementCreatorImpl(
sqlToUse, params != null ? Arrays.asList(params) : Collections.emptyList());
sqlToUse, (params != null ? Arrays.asList(params) : Collections.emptyList()));
}
@ -201,7 +204,7 @@ public class PreparedStatementCreatorFactory { @@ -201,7 +204,7 @@ public class PreparedStatementCreatorFactory {
public PreparedStatementCreatorImpl(String actualSql, List<?> parameters) {
this.actualSql = actualSql;
this.parameters = parameters;
if (parameters.size() != declaredParameters.size()) {
if (declaredParameters != null && parameters.size() != declaredParameters.size()) {
// Account for named parameters being used multiple times
Set<String> names = new HashSet<>();
for (int i = 0; i < parameters.size(); i++) {
@ -249,14 +252,14 @@ public class PreparedStatementCreatorFactory { @@ -249,14 +252,14 @@ public class PreparedStatementCreatorFactory {
int sqlColIndx = 1;
for (int i = 0; i < this.parameters.size(); i++) {
Object in = this.parameters.get(i);
SqlParameter declaredParameter;
SqlParameter declaredParameter = null;
// SqlParameterValue overrides declared parameter meta-data, in particular for
// independence from the declared parameter position in case of named parameters.
if (in instanceof SqlParameterValue sqlParameterValue) {
in = sqlParameterValue.getValue();
declaredParameter = sqlParameterValue;
}
else {
else if (declaredParameters != null) {
if (declaredParameters.size() <= i) {
throw new InvalidDataAccessApiUsageException(
"SQL [" + sql + "]: unable to access parameter number " + (i + 1) +
@ -265,7 +268,10 @@ public class PreparedStatementCreatorFactory { @@ -265,7 +268,10 @@ public class PreparedStatementCreatorFactory {
}
declaredParameter = declaredParameters.get(i);
}
if (in instanceof Iterable<?> entries && declaredParameter.getSqlType() != Types.ARRAY) {
if (declaredParameter == null) {
StatementCreatorUtils.setParameterValue(ps, sqlColIndx++, SqlTypeValue.TYPE_UNKNOWN, in);
}
else if (in instanceof Iterable<?> entries && declaredParameter.getSqlType() != Types.ARRAY) {
for (Object entry : entries) {
if (entry instanceof Object[] valueArray) {
for (Object argValue : valueArray) {

14
spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java

@ -28,6 +28,8 @@ import javax.sql.DataSource; @@ -28,6 +28,8 @@ import javax.sql.DataSource;
import org.springframework.beans.BeanUtils;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.PreparedStatementCreatorFactory;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.RowCallbackHandler;
import org.springframework.jdbc.core.RowMapper;
@ -206,7 +208,7 @@ final class DefaultJdbcClient implements JdbcClient { @@ -206,7 +208,7 @@ final class DefaultJdbcClient implements JdbcClient {
namedParamOps.query(this.sql, this.namedParams, rch);
}
else {
classicOps.query(this.sql, rch, this.indexedParams.toArray());
classicOps.query(getPreparedStatementCreatorForIndexedParams(), rch);
}
}
@ -214,7 +216,7 @@ final class DefaultJdbcClient implements JdbcClient { @@ -214,7 +216,7 @@ final class DefaultJdbcClient implements JdbcClient {
public <T> T query(ResultSetExtractor<T> rse) {
T result = (useNamedParams() ?
namedParamOps.query(this.sql, this.namedParams, rse) :
classicOps.query(this.sql, rse, this.indexedParams.toArray()));
classicOps.query(getPreparedStatementCreatorForIndexedParams(), rse));
Assert.state(result != null, "No result from ResultSetExtractor");
return result;
}
@ -223,14 +225,14 @@ final class DefaultJdbcClient implements JdbcClient { @@ -223,14 +225,14 @@ final class DefaultJdbcClient implements JdbcClient {
public int update() {
return (useNamedParams() ?
namedParamOps.update(this.sql, this.namedParamSource) :
classicOps.update(this.sql, this.indexedParams.toArray()));
classicOps.update(getPreparedStatementCreatorForIndexedParams()));
}
@Override
public int update(KeyHolder generatedKeyHolder) {
return (useNamedParams() ?
namedParamOps.update(this.sql, this.namedParamSource, generatedKeyHolder) :
classicOps.update(this.sql, this.indexedParams.toArray(), generatedKeyHolder));
classicOps.update(getPreparedStatementCreatorForIndexedParams(), generatedKeyHolder));
}
private boolean useNamedParams() {
@ -245,6 +247,10 @@ final class DefaultJdbcClient implements JdbcClient { @@ -245,6 +247,10 @@ final class DefaultJdbcClient implements JdbcClient {
return hasNamedParams;
}
private PreparedStatementCreator getPreparedStatementCreatorForIndexedParams() {
return new PreparedStatementCreatorFactory(this.sql).newPreparedStatementCreator(this.indexedParams);
}
private class IndexedParamResultQuerySpec implements ResultQuerySpec {

32
spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIndexedParameterTests.java

@ -20,6 +20,7 @@ import java.sql.Connection; @@ -20,6 +20,7 @@ import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
@ -35,6 +36,8 @@ import org.junit.jupiter.api.Test; @@ -35,6 +36,8 @@ import org.junit.jupiter.api.Test;
import org.springframework.jdbc.Customer;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
@ -56,6 +59,9 @@ public class JdbcClientIndexedParameterTests { @@ -56,6 +59,9 @@ public class JdbcClientIndexedParameterTests {
private static final String UPDATE_NAMED_PARAMETERS =
"update seat_status set booking_id = null where performance_id = ? and price_band_id = ?";
private static final String INSERT_GENERATE_KEYS =
"insert into show (name) values(?)";
private static final String[] COLUMN_NAMES = new String[] {"id", "forename"};
@ -67,6 +73,8 @@ public class JdbcClientIndexedParameterTests { @@ -67,6 +73,8 @@ public class JdbcClientIndexedParameterTests {
private ResultSet resultSet = mock();
private ResultSetMetaData resultSetMetaData = mock();
private DatabaseMetaData databaseMetaData = mock();
private JdbcClient client = JdbcClient.create(dataSource);
@ -329,4 +337,28 @@ public class JdbcClientIndexedParameterTests { @@ -329,4 +337,28 @@ public class JdbcClientIndexedParameterTests {
verify(connection).close();
}
@Test
public void testUpdateAndGeneratedKeys() throws SQLException {
given(resultSetMetaData.getColumnCount()).willReturn(1);
given(resultSetMetaData.getColumnLabel(1)).willReturn("1");
given(resultSet.getMetaData()).willReturn(resultSetMetaData);
given(resultSet.next()).willReturn(true, false);
given(resultSet.getObject(1)).willReturn(11);
given(preparedStatement.executeUpdate()).willReturn(1);
given(preparedStatement.getGeneratedKeys()).willReturn(resultSet);
given(connection.prepareStatement(INSERT_GENERATE_KEYS, PreparedStatement.RETURN_GENERATED_KEYS))
.willReturn(preparedStatement);
KeyHolder generatedKeyHolder = new GeneratedKeyHolder();
int rowsAffected = client.sql(INSERT_GENERATE_KEYS).param("rod").update(generatedKeyHolder);
assertThat(rowsAffected).isEqualTo(1);
assertThat(generatedKeyHolder.getKeyList()).hasSize(1);
assertThat(generatedKeyHolder.getKey()).isEqualTo(11);
verify(preparedStatement).setString(1, "rod");
verify(resultSet).close();
verify(preparedStatement).close();
verify(connection).close();
}
}

34
spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientNamedParameterTests.java

@ -20,6 +20,7 @@ import java.sql.Connection; @@ -20,6 +20,7 @@ import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
@ -37,6 +38,8 @@ import org.junit.jupiter.api.Test; @@ -37,6 +38,8 @@ import org.junit.jupiter.api.Test;
import org.springframework.jdbc.Customer;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
@ -62,6 +65,11 @@ public class JdbcClientNamedParameterTests { @@ -62,6 +65,11 @@ public class JdbcClientNamedParameterTests {
private static final String UPDATE_NAMED_PARAMETERS_PARSED =
"update seat_status set booking_id = null where performance_id = ? and price_band_id = ?";
private static final String INSERT_GENERATE_KEYS =
"insert into show (name) values(:name)";
private static final String INSERT_GENERATE_KEYS_PARSED =
"insert into show (name) values(?)";
private static final String[] COLUMN_NAMES = new String[] {"id", "forename"};
@ -73,6 +81,8 @@ public class JdbcClientNamedParameterTests { @@ -73,6 +81,8 @@ public class JdbcClientNamedParameterTests {
private ResultSet resultSet = mock();
private ResultSetMetaData resultSetMetaData = mock();
private DatabaseMetaData databaseMetaData = mock();
private JdbcClient client = JdbcClient.create(dataSource);
@ -335,4 +345,28 @@ public class JdbcClientNamedParameterTests { @@ -335,4 +345,28 @@ public class JdbcClientNamedParameterTests {
verify(connection).close();
}
@Test
public void testUpdateAndGeneratedKeys() throws SQLException {
given(resultSetMetaData.getColumnCount()).willReturn(1);
given(resultSetMetaData.getColumnLabel(1)).willReturn("1");
given(resultSet.getMetaData()).willReturn(resultSetMetaData);
given(resultSet.next()).willReturn(true, false);
given(resultSet.getObject(1)).willReturn(11);
given(preparedStatement.executeUpdate()).willReturn(1);
given(preparedStatement.getGeneratedKeys()).willReturn(resultSet);
given(connection.prepareStatement(INSERT_GENERATE_KEYS_PARSED, PreparedStatement.RETURN_GENERATED_KEYS))
.willReturn(preparedStatement);
KeyHolder generatedKeyHolder = new GeneratedKeyHolder();
int rowsAffected = client.sql(INSERT_GENERATE_KEYS).param("name", "rod").update(generatedKeyHolder);
assertThat(rowsAffected).isEqualTo(1);
assertThat(generatedKeyHolder.getKeyList()).hasSize(1);
assertThat(generatedKeyHolder.getKey()).isEqualTo(11);
verify(preparedStatement).setString(1, "rod");
verify(resultSet).close();
verify(preparedStatement).close();
verify(connection).close();
}
}

6
spring-jdbc/src/test/java/org/springframework/jdbc/object/SqlUpdateTests.java

@ -211,9 +211,8 @@ public class SqlUpdateTests { @@ -211,9 +211,8 @@ public class SqlUpdateTests {
given(resultSet.getObject(1)).willReturn(11);
given(preparedStatement.executeUpdate()).willReturn(1);
given(preparedStatement.getGeneratedKeys()).willReturn(resultSet);
given(connection.prepareStatement(INSERT_GENERATE_KEYS,
PreparedStatement.RETURN_GENERATED_KEYS)
).willReturn(preparedStatement);
given(connection.prepareStatement(INSERT_GENERATE_KEYS, PreparedStatement.RETURN_GENERATED_KEYS))
.willReturn(preparedStatement);
GeneratedKeysUpdater pc = new GeneratedKeysUpdater();
KeyHolder generatedKeyHolder = new GeneratedKeyHolder();
@ -294,6 +293,7 @@ public class SqlUpdateTests { @@ -294,6 +293,7 @@ public class SqlUpdateTests {
pc::run);
}
private class Updater extends SqlUpdate {
public Updater() {

Loading…
Cancel
Save