From 855fe39b7f462fdfddb54cbe9e252ea0b444944e Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Sun, 3 Sep 2023 00:44:11 +0200 Subject: [PATCH] Use PreparedStatementCreator for query/update with indexed params Closes gh-31122 --- .../core/PreparedStatementCreatorFactory.java | 24 ++++++++----- .../jdbc/core/simple/DefaultJdbcClient.java | 14 +++++--- .../JdbcClientIndexedParameterTests.java | 32 +++++++++++++++++ .../simple/JdbcClientNamedParameterTests.java | 34 +++++++++++++++++++ .../jdbc/object/SqlUpdateTests.java | 6 ++-- 5 files changed, 94 insertions(+), 16 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java index e01abfaeb16..ddc9ac073c6 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,8 +46,9 @@ public class PreparedStatementCreatorFactory { /** The SQL, which won't change when the parameters change. */ private final String sql; - /** List of SqlParameter objects (may not be {@code null}). */ - private final List declaredParameters; + /** List of SqlParameter objects (may be {@code null}). */ + @Nullable + private List declaredParameters; private int resultSetType = ResultSet.TYPE_FORWARD_ONLY; @@ -66,7 +67,6 @@ public class PreparedStatementCreatorFactory { */ public PreparedStatementCreatorFactory(String sql) { this.sql = sql; - this.declaredParameters = new ArrayList<>(); } /** @@ -104,6 +104,9 @@ public class PreparedStatementCreatorFactory { * @param param the parameter to add to the list of declared parameters */ public void addParameter(SqlParameter param) { + if (this.declaredParameters == null) { + this.declaredParameters = new ArrayList<>(); + } this.declaredParameters.add(param); } @@ -180,7 +183,7 @@ public class PreparedStatementCreatorFactory { */ public PreparedStatementCreator newPreparedStatementCreator(String sqlToUse, @Nullable Object[] params) { return new PreparedStatementCreatorImpl( - sqlToUse, params != null ? Arrays.asList(params) : Collections.emptyList()); + sqlToUse, (params != null ? Arrays.asList(params) : Collections.emptyList())); } @@ -201,7 +204,7 @@ public class PreparedStatementCreatorFactory { public PreparedStatementCreatorImpl(String actualSql, List parameters) { this.actualSql = actualSql; this.parameters = parameters; - if (parameters.size() != declaredParameters.size()) { + if (declaredParameters != null && parameters.size() != declaredParameters.size()) { // Account for named parameters being used multiple times Set names = new HashSet<>(); for (int i = 0; i < parameters.size(); i++) { @@ -249,14 +252,14 @@ public class PreparedStatementCreatorFactory { int sqlColIndx = 1; for (int i = 0; i < this.parameters.size(); i++) { Object in = this.parameters.get(i); - SqlParameter declaredParameter; + SqlParameter declaredParameter = null; // SqlParameterValue overrides declared parameter meta-data, in particular for // independence from the declared parameter position in case of named parameters. if (in instanceof SqlParameterValue sqlParameterValue) { in = sqlParameterValue.getValue(); declaredParameter = sqlParameterValue; } - else { + else if (declaredParameters != null) { if (declaredParameters.size() <= i) { throw new InvalidDataAccessApiUsageException( "SQL [" + sql + "]: unable to access parameter number " + (i + 1) + @@ -265,7 +268,10 @@ public class PreparedStatementCreatorFactory { } declaredParameter = declaredParameters.get(i); } - if (in instanceof Iterable entries && declaredParameter.getSqlType() != Types.ARRAY) { + if (declaredParameter == null) { + StatementCreatorUtils.setParameterValue(ps, sqlColIndx++, SqlTypeValue.TYPE_UNKNOWN, in); + } + else if (in instanceof Iterable entries && declaredParameter.getSqlType() != Types.ARRAY) { for (Object entry : entries) { if (entry instanceof Object[] valueArray) { for (Object argValue : valueArray) { diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java index 03fef20454b..5275de7c544 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/simple/DefaultJdbcClient.java @@ -28,6 +28,8 @@ import javax.sql.DataSource; import org.springframework.beans.BeanUtils; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.PreparedStatementCreator; +import org.springframework.jdbc.core.PreparedStatementCreatorFactory; import org.springframework.jdbc.core.ResultSetExtractor; import org.springframework.jdbc.core.RowCallbackHandler; import org.springframework.jdbc.core.RowMapper; @@ -206,7 +208,7 @@ final class DefaultJdbcClient implements JdbcClient { namedParamOps.query(this.sql, this.namedParams, rch); } else { - classicOps.query(this.sql, rch, this.indexedParams.toArray()); + classicOps.query(getPreparedStatementCreatorForIndexedParams(), rch); } } @@ -214,7 +216,7 @@ final class DefaultJdbcClient implements JdbcClient { public T query(ResultSetExtractor rse) { T result = (useNamedParams() ? namedParamOps.query(this.sql, this.namedParams, rse) : - classicOps.query(this.sql, rse, this.indexedParams.toArray())); + classicOps.query(getPreparedStatementCreatorForIndexedParams(), rse)); Assert.state(result != null, "No result from ResultSetExtractor"); return result; } @@ -223,14 +225,14 @@ final class DefaultJdbcClient implements JdbcClient { public int update() { return (useNamedParams() ? namedParamOps.update(this.sql, this.namedParamSource) : - classicOps.update(this.sql, this.indexedParams.toArray())); + classicOps.update(getPreparedStatementCreatorForIndexedParams())); } @Override public int update(KeyHolder generatedKeyHolder) { return (useNamedParams() ? namedParamOps.update(this.sql, this.namedParamSource, generatedKeyHolder) : - classicOps.update(this.sql, this.indexedParams.toArray(), generatedKeyHolder)); + classicOps.update(getPreparedStatementCreatorForIndexedParams(), generatedKeyHolder)); } private boolean useNamedParams() { @@ -245,6 +247,10 @@ final class DefaultJdbcClient implements JdbcClient { return hasNamedParams; } + private PreparedStatementCreator getPreparedStatementCreatorForIndexedParams() { + return new PreparedStatementCreatorFactory(this.sql).newPreparedStatementCreator(this.indexedParams); + } + private class IndexedParamResultQuerySpec implements ResultQuerySpec { diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIndexedParameterTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIndexedParameterTests.java index 1a07a2c76c9..bd2ab5ba499 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIndexedParameterTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientIndexedParameterTests.java @@ -20,6 +20,7 @@ import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Types; import java.util.ArrayList; @@ -35,6 +36,8 @@ import org.junit.jupiter.api.Test; import org.springframework.jdbc.Customer; import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.jdbc.support.GeneratedKeyHolder; +import org.springframework.jdbc.support.KeyHolder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.anyString; @@ -56,6 +59,9 @@ public class JdbcClientIndexedParameterTests { private static final String UPDATE_NAMED_PARAMETERS = "update seat_status set booking_id = null where performance_id = ? and price_band_id = ?"; + private static final String INSERT_GENERATE_KEYS = + "insert into show (name) values(?)"; + private static final String[] COLUMN_NAMES = new String[] {"id", "forename"}; @@ -67,6 +73,8 @@ public class JdbcClientIndexedParameterTests { private ResultSet resultSet = mock(); + private ResultSetMetaData resultSetMetaData = mock(); + private DatabaseMetaData databaseMetaData = mock(); private JdbcClient client = JdbcClient.create(dataSource); @@ -329,4 +337,28 @@ public class JdbcClientIndexedParameterTests { verify(connection).close(); } + @Test + public void testUpdateAndGeneratedKeys() throws SQLException { + given(resultSetMetaData.getColumnCount()).willReturn(1); + given(resultSetMetaData.getColumnLabel(1)).willReturn("1"); + given(resultSet.getMetaData()).willReturn(resultSetMetaData); + given(resultSet.next()).willReturn(true, false); + given(resultSet.getObject(1)).willReturn(11); + given(preparedStatement.executeUpdate()).willReturn(1); + given(preparedStatement.getGeneratedKeys()).willReturn(resultSet); + given(connection.prepareStatement(INSERT_GENERATE_KEYS, PreparedStatement.RETURN_GENERATED_KEYS)) + .willReturn(preparedStatement); + + KeyHolder generatedKeyHolder = new GeneratedKeyHolder(); + int rowsAffected = client.sql(INSERT_GENERATE_KEYS).param("rod").update(generatedKeyHolder); + + assertThat(rowsAffected).isEqualTo(1); + assertThat(generatedKeyHolder.getKeyList()).hasSize(1); + assertThat(generatedKeyHolder.getKey()).isEqualTo(11); + verify(preparedStatement).setString(1, "rod"); + verify(resultSet).close(); + verify(preparedStatement).close(); + verify(connection).close(); + } + } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientNamedParameterTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientNamedParameterTests.java index cd3348b09ce..f2fe53beb2b 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientNamedParameterTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/simple/JdbcClientNamedParameterTests.java @@ -20,6 +20,7 @@ import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Types; import java.util.ArrayList; @@ -37,6 +38,8 @@ import org.junit.jupiter.api.Test; import org.springframework.jdbc.Customer; import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.jdbc.support.GeneratedKeyHolder; +import org.springframework.jdbc.support.KeyHolder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.anyString; @@ -62,6 +65,11 @@ public class JdbcClientNamedParameterTests { private static final String UPDATE_NAMED_PARAMETERS_PARSED = "update seat_status set booking_id = null where performance_id = ? and price_band_id = ?"; + private static final String INSERT_GENERATE_KEYS = + "insert into show (name) values(:name)"; + private static final String INSERT_GENERATE_KEYS_PARSED = + "insert into show (name) values(?)"; + private static final String[] COLUMN_NAMES = new String[] {"id", "forename"}; @@ -73,6 +81,8 @@ public class JdbcClientNamedParameterTests { private ResultSet resultSet = mock(); + private ResultSetMetaData resultSetMetaData = mock(); + private DatabaseMetaData databaseMetaData = mock(); private JdbcClient client = JdbcClient.create(dataSource); @@ -335,4 +345,28 @@ public class JdbcClientNamedParameterTests { verify(connection).close(); } + @Test + public void testUpdateAndGeneratedKeys() throws SQLException { + given(resultSetMetaData.getColumnCount()).willReturn(1); + given(resultSetMetaData.getColumnLabel(1)).willReturn("1"); + given(resultSet.getMetaData()).willReturn(resultSetMetaData); + given(resultSet.next()).willReturn(true, false); + given(resultSet.getObject(1)).willReturn(11); + given(preparedStatement.executeUpdate()).willReturn(1); + given(preparedStatement.getGeneratedKeys()).willReturn(resultSet); + given(connection.prepareStatement(INSERT_GENERATE_KEYS_PARSED, PreparedStatement.RETURN_GENERATED_KEYS)) + .willReturn(preparedStatement); + + KeyHolder generatedKeyHolder = new GeneratedKeyHolder(); + int rowsAffected = client.sql(INSERT_GENERATE_KEYS).param("name", "rod").update(generatedKeyHolder); + + assertThat(rowsAffected).isEqualTo(1); + assertThat(generatedKeyHolder.getKeyList()).hasSize(1); + assertThat(generatedKeyHolder.getKey()).isEqualTo(11); + verify(preparedStatement).setString(1, "rod"); + verify(resultSet).close(); + verify(preparedStatement).close(); + verify(connection).close(); + } + } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/object/SqlUpdateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/object/SqlUpdateTests.java index 8f29e657cd8..2904dec4139 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/object/SqlUpdateTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/object/SqlUpdateTests.java @@ -211,9 +211,8 @@ public class SqlUpdateTests { given(resultSet.getObject(1)).willReturn(11); given(preparedStatement.executeUpdate()).willReturn(1); given(preparedStatement.getGeneratedKeys()).willReturn(resultSet); - given(connection.prepareStatement(INSERT_GENERATE_KEYS, - PreparedStatement.RETURN_GENERATED_KEYS) - ).willReturn(preparedStatement); + given(connection.prepareStatement(INSERT_GENERATE_KEYS, PreparedStatement.RETURN_GENERATED_KEYS)) + .willReturn(preparedStatement); GeneratedKeysUpdater pc = new GeneratedKeysUpdater(); KeyHolder generatedKeyHolder = new GeneratedKeyHolder(); @@ -294,6 +293,7 @@ public class SqlUpdateTests { pc::run); } + private class Updater extends SqlUpdate { public Updater() {