diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/AotRepositoryFragmentSupport.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/AotRepositoryFragmentSupport.java index 52b14d41d..d18014220 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/AotRepositoryFragmentSupport.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/AotRepositoryFragmentSupport.java @@ -38,7 +38,6 @@ import org.springframework.data.expression.ValueExpression; import org.springframework.data.jdbc.core.JdbcAggregateOperations; import org.springframework.data.jdbc.core.convert.JdbcColumnTypes; import org.springframework.data.jdbc.core.mapping.JdbcValue; -import org.springframework.data.jdbc.repository.query.EscapingParameterSource; import org.springframework.data.jdbc.repository.query.JdbcParameters; import org.springframework.data.jdbc.repository.query.JdbcValueBindUtil; import org.springframework.data.jdbc.repository.query.RowMapperFactory; @@ -60,6 +59,8 @@ import org.springframework.util.ConcurrentLruCache; /** * Support class for JDBC AOT repository fragments. + *

+ * This class is indented to be used by generated AOT fragments and not to be used directly. * * @author Mark Paluch * @since 4.0 @@ -141,10 +142,6 @@ public class AotRepositoryFragmentSupport { return DataAccessUtils.uniqueResult(results); } - protected SqlParameterSource escapingParameterSource(SqlParameterSource parameterSource) { - return new EscapingParameterSource(parameterSource, getDialect().getLikeEscaper()); - } - protected @Nullable Object escape(@Nullable Object value) { if (value == null) { diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java index 6919f7264..6e43d3de1 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java @@ -57,7 +57,6 @@ import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.RowMapperResultSetExtractor; import org.springframework.jdbc.core.SingleColumnRowMapper; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; -import org.springframework.jdbc.core.namedparam.SqlParameterSource; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; @@ -171,7 +170,7 @@ class JdbcCodeBlocks { }, b -> b.add(";\n$]"))); } - builder.add(buildQuery(false, entityQuery, criteria, this.parameterSourceVariableName, this.queryVariableName)); + builder.add(buildQuery(false, entityQuery, criteria)); if (countQuery != null) { @@ -179,15 +178,11 @@ class JdbcCodeBlocks { countAll.beginControlFlow("$T $L = () ->", LongSupplier.class, context.localVariable("countAll")); - String parameterSourceVariableName = context - .localVariable("count" + StringUtils.capitalize(this.parameterSourceVariableName)); - String queryVariableName = context.localVariable("count" + StringUtils.capitalize(this.queryVariableName)); + countAll.add(buildQuery(true, countQuery, criteria)); - countAll.add(buildQuery(true, countQuery, criteria, parameterSourceVariableName, queryVariableName)); - - countAll.addStatement("$1T $2L = queryForObject($3L, $4L, new $5T<>($1T.class))", Number.class, - context.localVariable("count"), queryVariableName, parameterSourceVariableName, - SingleColumnRowMapper.class); + countAll.addStatement("$1T $2L = $3L.executeWith(($4L, $5L) -> queryForObject($4L, $5L, new $6T<>($1T.class)))", + Number.class, context.localVariable("count"), context.localVariable("countBuilder"), + context.localVariable("sql"), context.localVariable("paramSource"), SingleColumnRowMapper.class); countAll.addStatement("return $1L != null ? $1L.longValue() : 0L", context.localVariable("count")); @@ -202,12 +197,10 @@ class JdbcCodeBlocks { return builder.build(); } - private CodeBlock buildQuery(boolean count, DerivedAotQuery aotQuery, CriteriaDefinition criteria, - String parameterSourceVariableName, String queryVariableName) { + private CodeBlock buildQuery(boolean count, DerivedAotQuery aotQuery, CriteriaDefinition criteria) { Builder builder = CodeBlock.builder(); String selection = context.localVariable(count ? "countBuilder" : "builder"); - String rawParameterSource = context.localVariable(count ? "countRawParameterSource" : "rawParameterSource"); String method; if (aotQuery.isCount()) { @@ -256,13 +249,6 @@ class JdbcCodeBlocks { builder.add(";\n$]"); - // TODO Projections - - builder.addStatement("$1T $2L = new $1T()", MapSqlParameterSource.class, rawParameterSource); - builder.addStatement("$T $L = $L.build($L)", String.class, queryVariableName, selection, rawParameterSource); - builder.addStatement("$1T $2L = escapingParameterSource($3L)", SqlParameterSource.class, - parameterSourceVariableName, rawParameterSource); - return builder.build(); } @@ -567,7 +553,10 @@ class JdbcCodeBlocks { Assert.state(aotQuery != null, "AOT Query must not be null"); - Builder builder = CodeBlock.builder(); + return doBuild(); + } + + private CodeBlock doBuild() { MethodReturn methodReturn = context.getMethodReturn(); boolean isProjecting = methodReturn.isProjecting() @@ -581,128 +570,90 @@ class JdbcCodeBlocks { String result = context.localVariable("result"); String rowMapper = context.localVariable("rowMapper"); + ExecutionDecorator decorator = getExecutionDecorator(); + if (modifying.isPresent()) { - return update(returnType); + return update(decorator, returnType); } else if (aotQuery.isCount()) { - return count(result, returnType, queryResultType); + return count(decorator, result, queryResultType, returnType); } else if (aotQuery.isExists()) { - return exists(queryResultType); + return exists(decorator, queryResultType); } else if (aotQuery.isDelete()) { - return delete(rowMapper, result, queryResultType, returnType, actualReturnType); + return delete(decorator, rowMapper, result, queryResultType, returnType, actualReturnType); } else { + return select(decorator, rowMapper, result, queryResultType, isProjecting, methodReturn); + } + } - String resultSetExtractor = null; + private ExecutionDecorator getExecutionDecorator() { - if (rowMapperClass != null) { - builder.addStatement("$T $L = new $T()", RowMapper.class, rowMapper, rowMapperClass); - } else if (StringUtils.hasText(rowMapperRef)) { - builder.addStatement("$T $L = getRowMapperFactory().getRowMapper($S)", RowMapper.class, rowMapper, - rowMapperRef); - } else if (resultSetExtractorClass == null) { + if (aotQuery instanceof DerivedAotQuery) { - Type typeToRead; + return new ExecutionDecorator() { + @Override + public String decorate(String executionCode) { + String builder = context.localVariable("builder"); - if (isProjecting) { - typeToRead = context.getReturnedType().getDomainType(); - } else { - typeToRead = methodReturn.getActualReturnClass(); + return String.format("%s.executeWith((%s, %s) -> %s)", builder, query(), paramSource(), executionCode); } - builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper, - typeToRead); - } - - if (StringUtils.hasText(resultSetExtractorRef) || resultSetExtractorClass != null) { - - resultSetExtractor = context.localVariable("resultSetExtractor"); + @Override + public String query() { + return context.localVariable("sql"); + } - if (resultSetExtractorClass != null && (rowMapperClass != null || StringUtils.hasText(rowMapperRef))) { - builder.addStatement("$T $L = new $T($L)", ResultSetExtractor.class, resultSetExtractor, - resultSetExtractorClass, rowMapper); - } else if (resultSetExtractorClass != null) { - builder.addStatement("$T $L = new $T()", ResultSetExtractor.class, resultSetExtractor, - resultSetExtractorClass); - } else if (StringUtils.hasText(resultSetExtractorRef)) { - builder.addStatement("$T $L = getRowMapperFactory().getResultSetExtractor($S)", ResultSetExtractor.class, - resultSetExtractor, resultSetExtractorRef); + @Override + public String paramSource() { + return context.localVariable("paramSource"); } + }; + } + return new ExecutionDecorator() { + @Override + public String decorate(String executionCode) { + return executionCode; } - if (StringUtils.hasText(resultSetExtractor)) { - - builder.addStatement("return ($T) getJdbcOperations().query($L, $L, $L)", queryResultType, queryVariableName, - parameterSourceVariableName, resultSetExtractor); - - return builder.build(); + @Override + public String query() { + return queryVariableName; } - boolean dynamicProjection = StringUtils.hasText(context.getDynamicProjectionParameterName()); - Object queryResultTypeRef = dynamicProjection ? context.getDynamicProjectionParameterName() : queryResultType; - - if (queryMethod.isCollectionQuery() || queryMethod.isSliceQuery() || queryMethod.isPageQuery()) { - - builder.addStatement("$1T $2L = ($1T) getJdbcOperations().query($3L, $4L, new $5T<>($6L))", List.class, - result, queryVariableName, parameterSourceVariableName, RowMapperResultSetExtractor.class, rowMapper); - - if (queryMethod.isSliceQuery() || queryMethod.isPageQuery()) { - - String pageable = context.getPageableParameterName(); - - builder.addStatement( - "$1T $2L = ($1T) convertMany($3L, %s)".formatted(dynamicProjection ? "$4L" : "$4T.class"), List.class, - context.localVariable("converted"), result, queryResultTypeRef); - - if (queryMethod.isPageQuery()) { - - builder.addStatement("return $1T.getPage($2L, $3L, $4L)", PageableExecutionUtils.class, - context.localVariable("converted"), pageable, context.localVariable("countAll")); - } else { - - builder.addStatement("boolean $1L = $2L.isPaged() && $3L.size() > $2L.getPageSize()", - context.localVariable("hasNext"), pageable, context.localVariable("converted")); + @Override + public String paramSource() { + return parameterSourceVariableName; + } + }; - builder.addStatement("return new $1T($2L ? $3L.subList(0, $4L.getPageSize()) : $3L, $4L, $2L)", - SliceImpl.class, context.localVariable("hasNext"), context.localVariable("converted"), pageable); - } + } - return builder.build(); - } + /** + * Decorates an execution. Used to wrap execution with lambda or similar. + */ + interface ExecutionDecorator { - builder.addStatement("return ($T) convertMany($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"), - methodReturn.getTypeName(), result, queryResultTypeRef); - } else if (queryMethod.isStreamQuery()) { + /** + * Decorate the given executionCode with a wrapper. + * + * @param executionCode + * @return + */ + String decorate(String executionCode); - builder.addStatement("$1T $2L = getJdbcOperations().queryForStream($3L, $4L, $5L)", Stream.class, result, - queryVariableName, parameterSourceVariableName, rowMapper); - builder.addStatement("return ($T) convertMany($L, $T.class)", methodReturn.getTypeName(), result, - queryResultTypeRef); - } else { + String query(); - builder.addStatement("$T $L = queryForObject($L, $L, $L)", Object.class, result, queryVariableName, - parameterSourceVariableName, rowMapper); + String paramSource(); - if (methodReturn.isOptional()) { - builder.addStatement( - "return ($1T) $1T.ofNullable(convertOne($2L, %s))".formatted(dynamicProjection ? "$3L" : "$3T.class"), - Optional.class, result, queryResultTypeRef); - } else { - builder.addStatement("return ($T) convertOne($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"), - methodReturn.getTypeName(), result, queryResultTypeRef); - } - } - } - - return builder.build(); } - private CodeBlock update(Class returnType) { + private CodeBlock update(ExecutionDecorator decorator, Class returnType) { String result = context.localVariable("result"); Builder builder = CodeBlock.builder(); - LordOfTheStrings.InvocationBuilder invoke = LordOfTheStrings.invoke("getJdbcOperations().update($L, $L)", - queryVariableName, parameterSourceVariableName); + LordOfTheStrings.InvocationBuilder invoke = LordOfTheStrings + .invoke(decorator.decorate("getJdbcOperations().update($L, $L)"), decorator.query(), decorator.paramSource()); if (context.getMethodReturn().isVoid()) { builder.addStatement(invoke.build()); @@ -719,7 +670,7 @@ class JdbcCodeBlocks { return builder.build(); } - private CodeBlock delete(String rowMapper, String result, TypeName queryResultType, + private CodeBlock delete(ExecutionDecorator decorator, String rowMapper, String result, TypeName queryResultType, Class returnType, Type actualReturnType) { CodeBlock.Builder builder = CodeBlock.builder(); @@ -727,8 +678,9 @@ class JdbcCodeBlocks { builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper, context.getRepositoryInformation().getDomainType()); - builder.addStatement("$1T $2L = ($1T) getJdbcOperations().query($3L, $4L, new $5T<>($6L))", List.class, result, - queryVariableName, parameterSourceVariableName, RowMapperResultSetExtractor.class, rowMapper); + builder.addStatement( + "$1T $2L = ($1T) " + decorator.decorate("getJdbcOperations().query($3L, $4L, new $5T<>($6L))"), List.class, + result, decorator.query(), decorator.paramSource(), RowMapperResultSetExtractor.class, rowMapper); builder.addStatement("$L.forEach(getOperations()::delete)", result); @@ -745,28 +697,140 @@ class JdbcCodeBlocks { return builder.build(); } - private CodeBlock count(String result, Class returnType, TypeName queryResultType) { + private CodeBlock count(ExecutionDecorator decorator, String result, TypeName queryResultType, + Class returnType) { CodeBlock.Builder builder = CodeBlock.builder(); - builder.addStatement("$1T $2L = queryForObject($3L, $4L, new $5T<>($1T.class))", Number.class, result, - queryVariableName, parameterSourceVariableName, SingleColumnRowMapper.class); + builder.addStatement("$1T $2L = " + decorator.decorate("queryForObject($3L, $4L, new $5T<>($1T.class))"), + Number.class, result, decorator.query(), decorator.paramSource(), SingleColumnRowMapper.class); builder.addStatement(LordOfTheStrings.returning(returnType) // .number(result) // + .whenPrimitiveOrBoxed(short.class, "$1L.shortValue()", result) // + .whenPrimitiveOrBoxed(byte.class, "$1L.byteValue()", result) // + .whenPrimitiveOrBoxed(double.class, "$1L.doubleValue()", result) // + .whenPrimitiveOrBoxed(float.class, "$1L.floatValue()", result) // .otherwise("($T) convertOne($L, $T.class)", context.getMethodReturn().getTypeName(), result, queryResultType) // .build()); return builder.build(); } - private CodeBlock exists(TypeName queryResultType) { + private CodeBlock exists(ExecutionDecorator decorator, TypeName queryResultType) { CodeBlock.Builder builder = CodeBlock.builder(); - builder.addStatement("return ($T) getJdbcOperations().query($L, $L, $T::next)", queryResultType, - queryVariableName, parameterSourceVariableName, ResultSet.class); + builder.addStatement("return ($T) " + decorator.decorate("getJdbcOperations().query($L, $L, $T::next)"), + queryResultType, decorator.query(), decorator.paramSource(), ResultSet.class); + + return builder.build(); + } + + private CodeBlock select(ExecutionDecorator decorator, String rowMapper, String result, TypeName queryResultType, + boolean isProjecting, MethodReturn methodReturn) { + Builder builder = CodeBlock.builder(); + + String resultSetExtractor = null; + + if (rowMapperClass != null) { + builder.addStatement("$T $L = new $T()", RowMapper.class, rowMapper, rowMapperClass); + } else if (StringUtils.hasText(rowMapperRef)) { + builder.addStatement("$T $L = getRowMapperFactory().getRowMapper($S)", RowMapper.class, rowMapper, + rowMapperRef); + } else if (resultSetExtractorClass == null) { + + Type typeToRead; + + if (isProjecting) { + typeToRead = context.getReturnedType().getDomainType(); + } else { + typeToRead = methodReturn.getActualReturnClass(); + } + + builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper, typeToRead); + } + + if (StringUtils.hasText(resultSetExtractorRef) || resultSetExtractorClass != null) { + + resultSetExtractor = context.localVariable("resultSetExtractor"); + + if (resultSetExtractorClass != null && (rowMapperClass != null || StringUtils.hasText(rowMapperRef))) { + builder.addStatement("$T $L = new $T($L)", ResultSetExtractor.class, resultSetExtractor, + resultSetExtractorClass, rowMapper); + } else if (resultSetExtractorClass != null) { + builder.addStatement("$T $L = new $T()", ResultSetExtractor.class, resultSetExtractor, + resultSetExtractorClass); + } else if (StringUtils.hasText(resultSetExtractorRef)) { + builder.addStatement("$T $L = getRowMapperFactory().getResultSetExtractor($S)", ResultSetExtractor.class, + resultSetExtractor, resultSetExtractorRef); + } + } + + if (StringUtils.hasText(resultSetExtractor)) { + + builder.addStatement("return ($T) " + decorator.decorate("getJdbcOperations().query($L, $L, $L)"), + queryResultType, decorator.query(), decorator.paramSource(), resultSetExtractor); + + return builder.build(); + } + + boolean dynamicProjection = StringUtils.hasText(context.getDynamicProjectionParameterName()); + Object queryResultTypeRef = dynamicProjection ? context.getDynamicProjectionParameterName() : queryResultType; + + if (queryMethod.isCollectionQuery() || queryMethod.isSliceQuery() || queryMethod.isPageQuery()) { + + builder.addStatement( + "$1T $2L = ($1T) " + decorator.decorate("getJdbcOperations().query($3L, $4L, new $5T<>($6L))"), List.class, + result, decorator.query(), decorator.paramSource(), RowMapperResultSetExtractor.class, rowMapper); + + if (queryMethod.isSliceQuery() || queryMethod.isPageQuery()) { + + String pageable = context.getPageableParameterName(); + + builder.addStatement( + "$1T $2L = ($1T) convertMany($3L, %s)".formatted(dynamicProjection ? "$4L" : "$4T.class"), List.class, + context.localVariable("converted"), result, queryResultTypeRef); + + if (queryMethod.isPageQuery()) { + + builder.addStatement("return $1T.getPage($2L, $3L, $4L)", PageableExecutionUtils.class, + context.localVariable("converted"), pageable, context.localVariable("countAll")); + } else { + + builder.addStatement("boolean $1L = $2L.isPaged() && $3L.size() > $2L.getPageSize()", + context.localVariable("hasNext"), pageable, context.localVariable("converted")); + + builder.addStatement("return new $1T($2L ? $3L.subList(0, $4L.getPageSize()) : $3L, $4L, $2L)", + SliceImpl.class, context.localVariable("hasNext"), context.localVariable("converted"), pageable); + } + + return builder.build(); + } + + builder.addStatement("return ($T) convertMany($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"), + methodReturn.getTypeName(), result, queryResultTypeRef); + } else if (queryMethod.isStreamQuery()) { + + builder.addStatement("$1T $2L = " + decorator.decorate("getJdbcOperations().queryForStream($3L, $4L, $5L)"), + Stream.class, result, decorator.query(), decorator.paramSource(), rowMapper); + builder.addStatement("return ($T) convertMany($L, $T.class)", methodReturn.getTypeName(), result, + queryResultTypeRef); + } else { + + builder.addStatement("$T $L = " + decorator.decorate("queryForObject($L, $L, $L)"), Object.class, result, + decorator.query(), decorator.paramSource(), rowMapper); + + if (methodReturn.isOptional()) { + builder.addStatement( + "return ($1T) $1T.ofNullable(convertOne($2L, %s))".formatted(dynamicProjection ? "$3L" : "$3T.class"), + Optional.class, result, queryResultTypeRef); + } else { + builder.addStatement("return ($T) convertOne($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"), + methodReturn.getTypeName(), result, queryResultTypeRef); + } + } return builder.build(); } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java index b4efa8306..1873a4f68 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java @@ -42,8 +42,10 @@ import org.springframework.data.relational.core.sql.Select; import org.springframework.data.relational.core.sql.SelectBuilder; import org.springframework.data.relational.core.sql.Table; import org.springframework.data.relational.core.sql.render.SqlRenderer; +import org.springframework.data.repository.query.ParametersSource; import org.springframework.data.util.Predicates; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; +import org.springframework.jdbc.core.namedparam.SqlParameterSource; import org.springframework.lang.Contract; /** @@ -58,12 +60,14 @@ public class StatementFactory { private final JdbcConverter converter; private final RenderContextFactory renderContextFactory; private final QueryMapper queryMapper; + private final Dialect dialect; private final SqlGeneratorSource sqlGeneratorSource; public StatementFactory(JdbcConverter converter, Dialect dialect) { this.renderContextFactory = new RenderContextFactory(dialect); this.converter = converter; this.queryMapper = new QueryMapper(converter); + this.dialect = dialect; this.sqlGeneratorSource = new SqlGeneratorSource(converter, dialect); } @@ -170,6 +174,27 @@ public class StatementFactory { return this; } + /** + * Build the SQL statement and apply the given function to the SQL string and its parameters. + * + * @param function SQL statement function accepting SQL string and parameters. + * @return the function result. + * @param type of the function result. + */ + public T executeWith(StatementFunction function) { + + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + String sql = build(parameterSource); + + return function.apply(sql, new EscapingParameterSource(parameterSource, dialect.getLikeEscaper())); + } + + /** + * Build the SQL statement and assign parameters to the given {@link ParametersSource}. + * + * @param parameterSource the parameter source to be populated. + * @return the build SQL statement. + */ public String build(MapSqlParameterSource parameterSource) { SelectBuilder.SelectLimitOffset limitOffsetBuilder = createSelectClause(entity, table); @@ -260,5 +285,25 @@ public class StatementFactory { enum Mode { COUNT, EXISTS, SELECT, SLICE } + } + + /** + * Represents a function that accepts a SQL string and a {@link ParametersSource} as arguments and produces a result. + * Ideal to run statements using {@link org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations} . + */ + @FunctionalInterface + public interface StatementFunction { + + /** + * Applies this function to the given arguments. + * + * @param sql the SQL string. + * @param paramSource parameters for the SQL string. + * @return the function result. + */ + T apply(String sql, SqlParameterSource paramSource); + + } + }