diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/AotRepositoryFragmentSupport.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/AotRepositoryFragmentSupport.java index 52b14d41d..d18014220 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/AotRepositoryFragmentSupport.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/AotRepositoryFragmentSupport.java @@ -38,7 +38,6 @@ import org.springframework.data.expression.ValueExpression; import org.springframework.data.jdbc.core.JdbcAggregateOperations; import org.springframework.data.jdbc.core.convert.JdbcColumnTypes; import org.springframework.data.jdbc.core.mapping.JdbcValue; -import org.springframework.data.jdbc.repository.query.EscapingParameterSource; import org.springframework.data.jdbc.repository.query.JdbcParameters; import org.springframework.data.jdbc.repository.query.JdbcValueBindUtil; import org.springframework.data.jdbc.repository.query.RowMapperFactory; @@ -60,6 +59,8 @@ import org.springframework.util.ConcurrentLruCache; /** * Support class for JDBC AOT repository fragments. + *
+ * This class is indented to be used by generated AOT fragments and not to be used directly.
*
* @author Mark Paluch
* @since 4.0
@@ -141,10 +142,6 @@ public class AotRepositoryFragmentSupport {
return DataAccessUtils.uniqueResult(results);
}
- protected SqlParameterSource escapingParameterSource(SqlParameterSource parameterSource) {
- return new EscapingParameterSource(parameterSource, getDialect().getLikeEscaper());
- }
-
protected @Nullable Object escape(@Nullable Object value) {
if (value == null) {
diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java
index 6919f7264..6e43d3de1 100644
--- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java
+++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java
@@ -57,7 +57,6 @@ import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.RowMapperResultSetExtractor;
import org.springframework.jdbc.core.SingleColumnRowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
-import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
@@ -171,7 +170,7 @@ class JdbcCodeBlocks {
}, b -> b.add(";\n$]")));
}
- builder.add(buildQuery(false, entityQuery, criteria, this.parameterSourceVariableName, this.queryVariableName));
+ builder.add(buildQuery(false, entityQuery, criteria));
if (countQuery != null) {
@@ -179,15 +178,11 @@ class JdbcCodeBlocks {
countAll.beginControlFlow("$T $L = () ->", LongSupplier.class, context.localVariable("countAll"));
- String parameterSourceVariableName = context
- .localVariable("count" + StringUtils.capitalize(this.parameterSourceVariableName));
- String queryVariableName = context.localVariable("count" + StringUtils.capitalize(this.queryVariableName));
+ countAll.add(buildQuery(true, countQuery, criteria));
- countAll.add(buildQuery(true, countQuery, criteria, parameterSourceVariableName, queryVariableName));
-
- countAll.addStatement("$1T $2L = queryForObject($3L, $4L, new $5T<>($1T.class))", Number.class,
- context.localVariable("count"), queryVariableName, parameterSourceVariableName,
- SingleColumnRowMapper.class);
+ countAll.addStatement("$1T $2L = $3L.executeWith(($4L, $5L) -> queryForObject($4L, $5L, new $6T<>($1T.class)))",
+ Number.class, context.localVariable("count"), context.localVariable("countBuilder"),
+ context.localVariable("sql"), context.localVariable("paramSource"), SingleColumnRowMapper.class);
countAll.addStatement("return $1L != null ? $1L.longValue() : 0L", context.localVariable("count"));
@@ -202,12 +197,10 @@ class JdbcCodeBlocks {
return builder.build();
}
- private CodeBlock buildQuery(boolean count, DerivedAotQuery aotQuery, CriteriaDefinition criteria,
- String parameterSourceVariableName, String queryVariableName) {
+ private CodeBlock buildQuery(boolean count, DerivedAotQuery aotQuery, CriteriaDefinition criteria) {
Builder builder = CodeBlock.builder();
String selection = context.localVariable(count ? "countBuilder" : "builder");
- String rawParameterSource = context.localVariable(count ? "countRawParameterSource" : "rawParameterSource");
String method;
if (aotQuery.isCount()) {
@@ -256,13 +249,6 @@ class JdbcCodeBlocks {
builder.add(";\n$]");
- // TODO Projections
-
- builder.addStatement("$1T $2L = new $1T()", MapSqlParameterSource.class, rawParameterSource);
- builder.addStatement("$T $L = $L.build($L)", String.class, queryVariableName, selection, rawParameterSource);
- builder.addStatement("$1T $2L = escapingParameterSource($3L)", SqlParameterSource.class,
- parameterSourceVariableName, rawParameterSource);
-
return builder.build();
}
@@ -567,7 +553,10 @@ class JdbcCodeBlocks {
Assert.state(aotQuery != null, "AOT Query must not be null");
- Builder builder = CodeBlock.builder();
+ return doBuild();
+ }
+
+ private CodeBlock doBuild() {
MethodReturn methodReturn = context.getMethodReturn();
boolean isProjecting = methodReturn.isProjecting()
@@ -581,128 +570,90 @@ class JdbcCodeBlocks {
String result = context.localVariable("result");
String rowMapper = context.localVariable("rowMapper");
+ ExecutionDecorator decorator = getExecutionDecorator();
+
if (modifying.isPresent()) {
- return update(returnType);
+ return update(decorator, returnType);
} else if (aotQuery.isCount()) {
- return count(result, returnType, queryResultType);
+ return count(decorator, result, queryResultType, returnType);
} else if (aotQuery.isExists()) {
- return exists(queryResultType);
+ return exists(decorator, queryResultType);
} else if (aotQuery.isDelete()) {
- return delete(rowMapper, result, queryResultType, returnType, actualReturnType);
+ return delete(decorator, rowMapper, result, queryResultType, returnType, actualReturnType);
} else {
+ return select(decorator, rowMapper, result, queryResultType, isProjecting, methodReturn);
+ }
+ }
- String resultSetExtractor = null;
+ private ExecutionDecorator getExecutionDecorator() {
- if (rowMapperClass != null) {
- builder.addStatement("$T $L = new $T()", RowMapper.class, rowMapper, rowMapperClass);
- } else if (StringUtils.hasText(rowMapperRef)) {
- builder.addStatement("$T $L = getRowMapperFactory().getRowMapper($S)", RowMapper.class, rowMapper,
- rowMapperRef);
- } else if (resultSetExtractorClass == null) {
+ if (aotQuery instanceof DerivedAotQuery) {
- Type typeToRead;
+ return new ExecutionDecorator() {
+ @Override
+ public String decorate(String executionCode) {
+ String builder = context.localVariable("builder");
- if (isProjecting) {
- typeToRead = context.getReturnedType().getDomainType();
- } else {
- typeToRead = methodReturn.getActualReturnClass();
+ return String.format("%s.executeWith((%s, %s) -> %s)", builder, query(), paramSource(), executionCode);
}
- builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper,
- typeToRead);
- }
-
- if (StringUtils.hasText(resultSetExtractorRef) || resultSetExtractorClass != null) {
-
- resultSetExtractor = context.localVariable("resultSetExtractor");
+ @Override
+ public String query() {
+ return context.localVariable("sql");
+ }
- if (resultSetExtractorClass != null && (rowMapperClass != null || StringUtils.hasText(rowMapperRef))) {
- builder.addStatement("$T $L = new $T($L)", ResultSetExtractor.class, resultSetExtractor,
- resultSetExtractorClass, rowMapper);
- } else if (resultSetExtractorClass != null) {
- builder.addStatement("$T $L = new $T()", ResultSetExtractor.class, resultSetExtractor,
- resultSetExtractorClass);
- } else if (StringUtils.hasText(resultSetExtractorRef)) {
- builder.addStatement("$T $L = getRowMapperFactory().getResultSetExtractor($S)", ResultSetExtractor.class,
- resultSetExtractor, resultSetExtractorRef);
+ @Override
+ public String paramSource() {
+ return context.localVariable("paramSource");
}
+ };
+ }
+ return new ExecutionDecorator() {
+ @Override
+ public String decorate(String executionCode) {
+ return executionCode;
}
- if (StringUtils.hasText(resultSetExtractor)) {
-
- builder.addStatement("return ($T) getJdbcOperations().query($L, $L, $L)", queryResultType, queryVariableName,
- parameterSourceVariableName, resultSetExtractor);
-
- return builder.build();
+ @Override
+ public String query() {
+ return queryVariableName;
}
- boolean dynamicProjection = StringUtils.hasText(context.getDynamicProjectionParameterName());
- Object queryResultTypeRef = dynamicProjection ? context.getDynamicProjectionParameterName() : queryResultType;
-
- if (queryMethod.isCollectionQuery() || queryMethod.isSliceQuery() || queryMethod.isPageQuery()) {
-
- builder.addStatement("$1T $2L = ($1T) getJdbcOperations().query($3L, $4L, new $5T<>($6L))", List.class,
- result, queryVariableName, parameterSourceVariableName, RowMapperResultSetExtractor.class, rowMapper);
-
- if (queryMethod.isSliceQuery() || queryMethod.isPageQuery()) {
-
- String pageable = context.getPageableParameterName();
-
- builder.addStatement(
- "$1T $2L = ($1T) convertMany($3L, %s)".formatted(dynamicProjection ? "$4L" : "$4T.class"), List.class,
- context.localVariable("converted"), result, queryResultTypeRef);
-
- if (queryMethod.isPageQuery()) {
-
- builder.addStatement("return $1T.getPage($2L, $3L, $4L)", PageableExecutionUtils.class,
- context.localVariable("converted"), pageable, context.localVariable("countAll"));
- } else {
-
- builder.addStatement("boolean $1L = $2L.isPaged() && $3L.size() > $2L.getPageSize()",
- context.localVariable("hasNext"), pageable, context.localVariable("converted"));
+ @Override
+ public String paramSource() {
+ return parameterSourceVariableName;
+ }
+ };
- builder.addStatement("return new $1T($2L ? $3L.subList(0, $4L.getPageSize()) : $3L, $4L, $2L)",
- SliceImpl.class, context.localVariable("hasNext"), context.localVariable("converted"), pageable);
- }
+ }
- return builder.build();
- }
+ /**
+ * Decorates an execution. Used to wrap execution with lambda or similar.
+ */
+ interface ExecutionDecorator {
- builder.addStatement("return ($T) convertMany($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"),
- methodReturn.getTypeName(), result, queryResultTypeRef);
- } else if (queryMethod.isStreamQuery()) {
+ /**
+ * Decorate the given executionCode with a wrapper.
+ *
+ * @param executionCode
+ * @return
+ */
+ String decorate(String executionCode);
- builder.addStatement("$1T $2L = getJdbcOperations().queryForStream($3L, $4L, $5L)", Stream.class, result,
- queryVariableName, parameterSourceVariableName, rowMapper);
- builder.addStatement("return ($T) convertMany($L, $T.class)", methodReturn.getTypeName(), result,
- queryResultTypeRef);
- } else {
+ String query();
- builder.addStatement("$T $L = queryForObject($L, $L, $L)", Object.class, result, queryVariableName,
- parameterSourceVariableName, rowMapper);
+ String paramSource();
- if (methodReturn.isOptional()) {
- builder.addStatement(
- "return ($1T) $1T.ofNullable(convertOne($2L, %s))".formatted(dynamicProjection ? "$3L" : "$3T.class"),
- Optional.class, result, queryResultTypeRef);
- } else {
- builder.addStatement("return ($T) convertOne($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"),
- methodReturn.getTypeName(), result, queryResultTypeRef);
- }
- }
- }
-
- return builder.build();
}
- private CodeBlock update(Class> returnType) {
+ private CodeBlock update(ExecutionDecorator decorator, Class> returnType) {
String result = context.localVariable("result");
Builder builder = CodeBlock.builder();
- LordOfTheStrings.InvocationBuilder invoke = LordOfTheStrings.invoke("getJdbcOperations().update($L, $L)",
- queryVariableName, parameterSourceVariableName);
+ LordOfTheStrings.InvocationBuilder invoke = LordOfTheStrings
+ .invoke(decorator.decorate("getJdbcOperations().update($L, $L)"), decorator.query(), decorator.paramSource());
if (context.getMethodReturn().isVoid()) {
builder.addStatement(invoke.build());
@@ -719,7 +670,7 @@ class JdbcCodeBlocks {
return builder.build();
}
- private CodeBlock delete(String rowMapper, String result, TypeName queryResultType,
+ private CodeBlock delete(ExecutionDecorator decorator, String rowMapper, String result, TypeName queryResultType,
Class> returnType, Type actualReturnType) {
CodeBlock.Builder builder = CodeBlock.builder();
@@ -727,8 +678,9 @@ class JdbcCodeBlocks {
builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper,
context.getRepositoryInformation().getDomainType());
- builder.addStatement("$1T $2L = ($1T) getJdbcOperations().query($3L, $4L, new $5T<>($6L))", List.class, result,
- queryVariableName, parameterSourceVariableName, RowMapperResultSetExtractor.class, rowMapper);
+ builder.addStatement(
+ "$1T $2L = ($1T) " + decorator.decorate("getJdbcOperations().query($3L, $4L, new $5T<>($6L))"), List.class,
+ result, decorator.query(), decorator.paramSource(), RowMapperResultSetExtractor.class, rowMapper);
builder.addStatement("$L.forEach(getOperations()::delete)", result);
@@ -745,28 +697,140 @@ class JdbcCodeBlocks {
return builder.build();
}
- private CodeBlock count(String result, Class> returnType, TypeName queryResultType) {
+ private CodeBlock count(ExecutionDecorator decorator, String result, TypeName queryResultType,
+ Class> returnType) {
CodeBlock.Builder builder = CodeBlock.builder();
- builder.addStatement("$1T $2L = queryForObject($3L, $4L, new $5T<>($1T.class))", Number.class, result,
- queryVariableName, parameterSourceVariableName, SingleColumnRowMapper.class);
+ builder.addStatement("$1T $2L = " + decorator.decorate("queryForObject($3L, $4L, new $5T<>($1T.class))"),
+ Number.class, result, decorator.query(), decorator.paramSource(), SingleColumnRowMapper.class);
builder.addStatement(LordOfTheStrings.returning(returnType) //
.number(result) //
+ .whenPrimitiveOrBoxed(short.class, "$1L.shortValue()", result) //
+ .whenPrimitiveOrBoxed(byte.class, "$1L.byteValue()", result) //
+ .whenPrimitiveOrBoxed(double.class, "$1L.doubleValue()", result) //
+ .whenPrimitiveOrBoxed(float.class, "$1L.floatValue()", result) //
.otherwise("($T) convertOne($L, $T.class)", context.getMethodReturn().getTypeName(), result, queryResultType) //
.build());
return builder.build();
}
- private CodeBlock exists(TypeName queryResultType) {
+ private CodeBlock exists(ExecutionDecorator decorator, TypeName queryResultType) {
CodeBlock.Builder builder = CodeBlock.builder();
- builder.addStatement("return ($T) getJdbcOperations().query($L, $L, $T::next)", queryResultType,
- queryVariableName, parameterSourceVariableName, ResultSet.class);
+ builder.addStatement("return ($T) " + decorator.decorate("getJdbcOperations().query($L, $L, $T::next)"),
+ queryResultType, decorator.query(), decorator.paramSource(), ResultSet.class);
+
+ return builder.build();
+ }
+
+ private CodeBlock select(ExecutionDecorator decorator, String rowMapper, String result, TypeName queryResultType,
+ boolean isProjecting, MethodReturn methodReturn) {
+ Builder builder = CodeBlock.builder();
+
+ String resultSetExtractor = null;
+
+ if (rowMapperClass != null) {
+ builder.addStatement("$T $L = new $T()", RowMapper.class, rowMapper, rowMapperClass);
+ } else if (StringUtils.hasText(rowMapperRef)) {
+ builder.addStatement("$T $L = getRowMapperFactory().getRowMapper($S)", RowMapper.class, rowMapper,
+ rowMapperRef);
+ } else if (resultSetExtractorClass == null) {
+
+ Type typeToRead;
+
+ if (isProjecting) {
+ typeToRead = context.getReturnedType().getDomainType();
+ } else {
+ typeToRead = methodReturn.getActualReturnClass();
+ }
+
+ builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper, typeToRead);
+ }
+
+ if (StringUtils.hasText(resultSetExtractorRef) || resultSetExtractorClass != null) {
+
+ resultSetExtractor = context.localVariable("resultSetExtractor");
+
+ if (resultSetExtractorClass != null && (rowMapperClass != null || StringUtils.hasText(rowMapperRef))) {
+ builder.addStatement("$T $L = new $T($L)", ResultSetExtractor.class, resultSetExtractor,
+ resultSetExtractorClass, rowMapper);
+ } else if (resultSetExtractorClass != null) {
+ builder.addStatement("$T $L = new $T()", ResultSetExtractor.class, resultSetExtractor,
+ resultSetExtractorClass);
+ } else if (StringUtils.hasText(resultSetExtractorRef)) {
+ builder.addStatement("$T $L = getRowMapperFactory().getResultSetExtractor($S)", ResultSetExtractor.class,
+ resultSetExtractor, resultSetExtractorRef);
+ }
+ }
+
+ if (StringUtils.hasText(resultSetExtractor)) {
+
+ builder.addStatement("return ($T) " + decorator.decorate("getJdbcOperations().query($L, $L, $L)"),
+ queryResultType, decorator.query(), decorator.paramSource(), resultSetExtractor);
+
+ return builder.build();
+ }
+
+ boolean dynamicProjection = StringUtils.hasText(context.getDynamicProjectionParameterName());
+ Object queryResultTypeRef = dynamicProjection ? context.getDynamicProjectionParameterName() : queryResultType;
+
+ if (queryMethod.isCollectionQuery() || queryMethod.isSliceQuery() || queryMethod.isPageQuery()) {
+
+ builder.addStatement(
+ "$1T $2L = ($1T) " + decorator.decorate("getJdbcOperations().query($3L, $4L, new $5T<>($6L))"), List.class,
+ result, decorator.query(), decorator.paramSource(), RowMapperResultSetExtractor.class, rowMapper);
+
+ if (queryMethod.isSliceQuery() || queryMethod.isPageQuery()) {
+
+ String pageable = context.getPageableParameterName();
+
+ builder.addStatement(
+ "$1T $2L = ($1T) convertMany($3L, %s)".formatted(dynamicProjection ? "$4L" : "$4T.class"), List.class,
+ context.localVariable("converted"), result, queryResultTypeRef);
+
+ if (queryMethod.isPageQuery()) {
+
+ builder.addStatement("return $1T.getPage($2L, $3L, $4L)", PageableExecutionUtils.class,
+ context.localVariable("converted"), pageable, context.localVariable("countAll"));
+ } else {
+
+ builder.addStatement("boolean $1L = $2L.isPaged() && $3L.size() > $2L.getPageSize()",
+ context.localVariable("hasNext"), pageable, context.localVariable("converted"));
+
+ builder.addStatement("return new $1T($2L ? $3L.subList(0, $4L.getPageSize()) : $3L, $4L, $2L)",
+ SliceImpl.class, context.localVariable("hasNext"), context.localVariable("converted"), pageable);
+ }
+
+ return builder.build();
+ }
+
+ builder.addStatement("return ($T) convertMany($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"),
+ methodReturn.getTypeName(), result, queryResultTypeRef);
+ } else if (queryMethod.isStreamQuery()) {
+
+ builder.addStatement("$1T $2L = " + decorator.decorate("getJdbcOperations().queryForStream($3L, $4L, $5L)"),
+ Stream.class, result, decorator.query(), decorator.paramSource(), rowMapper);
+ builder.addStatement("return ($T) convertMany($L, $T.class)", methodReturn.getTypeName(), result,
+ queryResultTypeRef);
+ } else {
+
+ builder.addStatement("$T $L = " + decorator.decorate("queryForObject($L, $L, $L)"), Object.class, result,
+ decorator.query(), decorator.paramSource(), rowMapper);
+
+ if (methodReturn.isOptional()) {
+ builder.addStatement(
+ "return ($1T) $1T.ofNullable(convertOne($2L, %s))".formatted(dynamicProjection ? "$3L" : "$3T.class"),
+ Optional.class, result, queryResultTypeRef);
+ } else {
+ builder.addStatement("return ($T) convertOne($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"),
+ methodReturn.getTypeName(), result, queryResultTypeRef);
+ }
+ }
return builder.build();
}
diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java
index b4efa8306..1873a4f68 100644
--- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java
+++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java
@@ -42,8 +42,10 @@ import org.springframework.data.relational.core.sql.Select;
import org.springframework.data.relational.core.sql.SelectBuilder;
import org.springframework.data.relational.core.sql.Table;
import org.springframework.data.relational.core.sql.render.SqlRenderer;
+import org.springframework.data.repository.query.ParametersSource;
import org.springframework.data.util.Predicates;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
+import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.lang.Contract;
/**
@@ -58,12 +60,14 @@ public class StatementFactory {
private final JdbcConverter converter;
private final RenderContextFactory renderContextFactory;
private final QueryMapper queryMapper;
+ private final Dialect dialect;
private final SqlGeneratorSource sqlGeneratorSource;
public StatementFactory(JdbcConverter converter, Dialect dialect) {
this.renderContextFactory = new RenderContextFactory(dialect);
this.converter = converter;
this.queryMapper = new QueryMapper(converter);
+ this.dialect = dialect;
this.sqlGeneratorSource = new SqlGeneratorSource(converter, dialect);
}
@@ -170,6 +174,27 @@ public class StatementFactory {
return this;
}
+ /**
+ * Build the SQL statement and apply the given function to the SQL string and its parameters.
+ *
+ * @param function SQL statement function accepting SQL string and parameters.
+ * @return the function result.
+ * @param