diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQuery.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQuery.java index f50d7bfb8..3e7450e6c 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQuery.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQuery.java @@ -23,21 +23,18 @@ import java.sql.JDBCType; import java.sql.SQLType; import java.util.ArrayList; import java.util.Collection; +import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.function.Function; import java.util.function.Supplier; import org.springframework.beans.BeanInstantiationException; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.BeanFactory; -import org.springframework.data.jdbc.core.convert.JdbcColumnTypes; import org.springframework.data.jdbc.core.convert.JdbcConverter; import org.springframework.data.jdbc.core.mapping.JdbcValue; -import org.springframework.data.jdbc.support.JdbcUtil; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.repository.query.RelationalParameterAccessor; -import org.springframework.data.relational.repository.query.RelationalParameters; import org.springframework.data.relational.repository.query.RelationalParametersParameterAccessor; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; @@ -47,7 +44,6 @@ import org.springframework.data.repository.query.SpelEvaluator; import org.springframework.data.repository.query.SpelQueryContext; import org.springframework.data.util.Lazy; import org.springframework.data.util.TypeInformation; -import org.springframework.data.util.TypeUtils; import org.springframework.jdbc.core.ResultSetExtractor; import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; @@ -55,7 +51,6 @@ import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; -import org.springframework.util.ConcurrentReferenceHashMap; import org.springframework.util.ObjectUtils; /** @@ -75,7 +70,7 @@ import org.springframework.util.ObjectUtils; */ public class StringBasedJdbcQuery extends AbstractJdbcQuery { - private static final String PARAMETER_NEEDS_TO_BE_NAMED = "For queries with named parameters you need to provide names for method parameters; Use @Param for query method parameters, or when on Java 8+ use the javac flag -parameters"; + private static final String PARAMETER_NEEDS_TO_BE_NAMED = "For queries with named parameters you need to provide names for method parameters; Use @Param for query method parameters, or use the javac flag -parameters"; private final JdbcConverter converter; private final RowMapperFactory rowMapperFactory; private final SpelEvaluator spelEvaluator; @@ -188,77 +183,103 @@ public class StringBasedJdbcQuery extends AbstractJdbcQuery { private MapSqlParameterSource bindParameters(RelationalParameterAccessor accessor) { - MapSqlParameterSource parameters = new MapSqlParameterSource(); - Parameters bindableParameters = accessor.getBindableParameters(); + MapSqlParameterSource parameters = new MapSqlParameterSource( + new LinkedHashMap<>(bindableParameters.getNumberOfParameters(), 1.0f)); for (Parameter bindableParameter : bindableParameters) { - convertAndAddParameter(parameters, bindableParameter, accessor.getBindableValue(bindableParameter.getIndex())); + + Object value = accessor.getBindableValue(bindableParameter.getIndex()); + String parameterName = bindableParameter.getName() + .orElseThrow(() -> new IllegalStateException(PARAMETER_NEEDS_TO_BE_NAMED)); + JdbcParameters.JdbcParameter parameter = getQueryMethod().getParameters() + .getParameter(bindableParameter.getIndex()); + + JdbcValue jdbcValue = writeValue(value, parameter.getTypeInformation(), parameter); + SQLType jdbcType = jdbcValue.getJdbcType(); + + if (jdbcType == null) { + parameters.addValue(parameterName, jdbcValue.getValue()); + } else { + parameters.addValue(parameterName, jdbcValue.getValue(), jdbcType.getVendorTypeNumber()); + } } return parameters; } - private void convertAndAddParameter(MapSqlParameterSource parameters, Parameter p, Object value) { + private JdbcValue writeValue(@Nullable Object value, TypeInformation typeInformation, + JdbcParameters.JdbcParameter parameter) { - String parameterName = p.getName().orElseThrow(() -> new IllegalStateException(PARAMETER_NEEDS_TO_BE_NAMED)); + if (value == null) { + return JdbcValue.of(value, parameter.getSqlType()); + } - JdbcParameters.JdbcParameter parameter = getQueryMethod().getParameters().getParameter(p.getIndex()); - TypeInformation typeInformation = parameter.getTypeInformation(); + if (typeInformation.isCollectionLike() && value instanceof Collection collection) { - JdbcValue jdbcValue; - if (typeInformation.isCollectionLike() // - && value instanceof Collection collectionValue// - ) { - if ( typeInformation.getActualType().getType().isArray() ){ + TypeInformation actualType = typeInformation.getActualType(); - TypeInformation arrayElementType = typeInformation.getActualType().getActualType(); + // tuple-binding + if (actualType != null && actualType.getType().isArray()) { - List mapped = new ArrayList<>(); + TypeInformation nestedElementType = actualType.getRequiredActualType(); + return writeCollection(collection, JDBCType.OTHER, + array -> writeArrayValue(parameter, array, nestedElementType)); + } - for (Object array : collectionValue) { - int length = Array.getLength(array); - Object[] mappedArray = new Object[length]; + // parameter expansion + return writeCollection(collection, parameter.getActualSqlType(), + it -> converter.writeJdbcValue(it, typeInformation.getRequiredActualType(), parameter.getActualSqlType())); + } - for (int i = 0; i < length; i++) { - Object element = Array.get(array, i); - JdbcValue elementJdbcValue = converter.writeJdbcValue(element, arrayElementType, parameter.getActualSqlType()); + SQLType sqlType = parameter.getSqlType(); + return converter.writeJdbcValue(value, typeInformation, sqlType); + } - mappedArray[i] = elementJdbcValue.getValue(); - } - mapped.add(mappedArray); - } - jdbcValue = JdbcValue.of(mapped, JDBCType.OTHER); + private JdbcValue writeCollection(Collection value, SQLType defaultType, Function mapper) { - } else { - List mapped = new ArrayList<>(); - SQLType jdbcType = null; + if (value.isEmpty()) { + return JdbcValue.of(value, defaultType); + } - TypeInformation actualType = typeInformation.getRequiredActualType(); - for (Object o : collectionValue) { - JdbcValue elementJdbcValue = converter.writeJdbcValue(o, actualType, parameter.getActualSqlType()); - if (jdbcType == null) { - jdbcType = elementJdbcValue.getJdbcType(); - } + JdbcValue jdbcValue; + List mapped = new ArrayList<>(value.size()); + SQLType jdbcType = null; - mapped.add(elementJdbcValue.getValue()); - } + for (Object o : value) { - jdbcValue = JdbcValue.of(mapped, jdbcType); + Object mappedValue = mapper.apply(o); + + if (mappedValue instanceof JdbcValue jv) { + if (jdbcType == null) { + jdbcType = jv.getJdbcType(); + } + mappedValue = jv.getValue(); } - } else { - SQLType sqlType = parameter.getSqlType(); - jdbcValue = converter.writeJdbcValue(value, typeInformation, sqlType); + mapped.add(mappedValue); } - SQLType jdbcType = jdbcValue.getJdbcType(); - if (jdbcType == null) { + jdbcValue = JdbcValue.of(mapped, jdbcType == null ? defaultType : jdbcType); - parameters.addValue(parameterName, jdbcValue.getValue()); - } else { - parameters.addValue(parameterName, jdbcValue.getValue(), jdbcType.getVendorTypeNumber()); + return jdbcValue; + } + + private Object[] writeArrayValue(JdbcParameters.JdbcParameter parameter, Object array, + TypeInformation nestedElementType) { + + int length = Array.getLength(array); + Object[] mappedArray = new Object[length]; + + for (int i = 0; i < length; i++) { + + Object element = Array.get(array, i); + JdbcValue elementJdbcValue = converter.writeJdbcValue(element, nestedElementType, parameter.getActualSqlType()); + + mappedArray[i] = elementJdbcValue.getValue(); } + + return mappedArray; } RowMapper determineRowMapper(ResultProcessor resultProcessor, boolean hasDynamicProjection) { diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java index 9f4773d85..8fcfd73f4 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java @@ -1350,7 +1350,9 @@ public class JdbcRepositoryIntegrationTests { new Object[]{three.idProp, "two"} // matches nothing ); - repository.findByListInTuple(tuples); + List result = repository.findByListInTuple(tuples); + + assertThat(result).containsOnly(two); } private Root createRoot(String namePrefix) {