Browse Source

Refine AOT execution for derived queries.

We now provide an extension to directly run queries constructed from a derived query method to avoid leaking internals into AOT-generated code.

Closes #2140
pull/2148/head
Mark Paluch 3 months ago
parent
commit
d7629313e9
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 7
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/AotRepositoryFragmentSupport.java
  2. 288
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java
  3. 45
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java

7
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/AotRepositoryFragmentSupport.java

@ -38,7 +38,6 @@ import org.springframework.data.expression.ValueExpression; @@ -38,7 +38,6 @@ import org.springframework.data.expression.ValueExpression;
import org.springframework.data.jdbc.core.JdbcAggregateOperations;
import org.springframework.data.jdbc.core.convert.JdbcColumnTypes;
import org.springframework.data.jdbc.core.mapping.JdbcValue;
import org.springframework.data.jdbc.repository.query.EscapingParameterSource;
import org.springframework.data.jdbc.repository.query.JdbcParameters;
import org.springframework.data.jdbc.repository.query.JdbcValueBindUtil;
import org.springframework.data.jdbc.repository.query.RowMapperFactory;
@ -60,6 +59,8 @@ import org.springframework.util.ConcurrentLruCache; @@ -60,6 +59,8 @@ import org.springframework.util.ConcurrentLruCache;
/**
* Support class for JDBC AOT repository fragments.
* <p>
* This class is indented to be used by generated AOT fragments and not to be used directly.
*
* @author Mark Paluch
* @since 4.0
@ -141,10 +142,6 @@ public class AotRepositoryFragmentSupport { @@ -141,10 +142,6 @@ public class AotRepositoryFragmentSupport {
return DataAccessUtils.uniqueResult(results);
}
protected SqlParameterSource escapingParameterSource(SqlParameterSource parameterSource) {
return new EscapingParameterSource(parameterSource, getDialect().getLikeEscaper());
}
protected @Nullable Object escape(@Nullable Object value) {
if (value == null) {

288
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java

@ -57,7 +57,6 @@ import org.springframework.jdbc.core.RowMapper; @@ -57,7 +57,6 @@ import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.RowMapperResultSetExtractor;
import org.springframework.jdbc.core.SingleColumnRowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
@ -171,7 +170,7 @@ class JdbcCodeBlocks { @@ -171,7 +170,7 @@ class JdbcCodeBlocks {
}, b -> b.add(";\n$]")));
}
builder.add(buildQuery(false, entityQuery, criteria, this.parameterSourceVariableName, this.queryVariableName));
builder.add(buildQuery(false, entityQuery, criteria));
if (countQuery != null) {
@ -179,15 +178,11 @@ class JdbcCodeBlocks { @@ -179,15 +178,11 @@ class JdbcCodeBlocks {
countAll.beginControlFlow("$T $L = () ->", LongSupplier.class, context.localVariable("countAll"));
String parameterSourceVariableName = context
.localVariable("count" + StringUtils.capitalize(this.parameterSourceVariableName));
String queryVariableName = context.localVariable("count" + StringUtils.capitalize(this.queryVariableName));
countAll.add(buildQuery(true, countQuery, criteria));
countAll.add(buildQuery(true, countQuery, criteria, parameterSourceVariableName, queryVariableName));
countAll.addStatement("$1T $2L = queryForObject($3L, $4L, new $5T<>($1T.class))", Number.class,
context.localVariable("count"), queryVariableName, parameterSourceVariableName,
SingleColumnRowMapper.class);
countAll.addStatement("$1T $2L = $3L.executeWith(($4L, $5L) -> queryForObject($4L, $5L, new $6T<>($1T.class)))",
Number.class, context.localVariable("count"), context.localVariable("countBuilder"),
context.localVariable("sql"), context.localVariable("paramSource"), SingleColumnRowMapper.class);
countAll.addStatement("return $1L != null ? $1L.longValue() : 0L", context.localVariable("count"));
@ -202,12 +197,10 @@ class JdbcCodeBlocks { @@ -202,12 +197,10 @@ class JdbcCodeBlocks {
return builder.build();
}
private CodeBlock buildQuery(boolean count, DerivedAotQuery aotQuery, CriteriaDefinition criteria,
String parameterSourceVariableName, String queryVariableName) {
private CodeBlock buildQuery(boolean count, DerivedAotQuery aotQuery, CriteriaDefinition criteria) {
Builder builder = CodeBlock.builder();
String selection = context.localVariable(count ? "countBuilder" : "builder");
String rawParameterSource = context.localVariable(count ? "countRawParameterSource" : "rawParameterSource");
String method;
if (aotQuery.isCount()) {
@ -256,13 +249,6 @@ class JdbcCodeBlocks { @@ -256,13 +249,6 @@ class JdbcCodeBlocks {
builder.add(";\n$]");
// TODO Projections
builder.addStatement("$1T $2L = new $1T()", MapSqlParameterSource.class, rawParameterSource);
builder.addStatement("$T $L = $L.build($L)", String.class, queryVariableName, selection, rawParameterSource);
builder.addStatement("$1T $2L = escapingParameterSource($3L)", SqlParameterSource.class,
parameterSourceVariableName, rawParameterSource);
return builder.build();
}
@ -567,7 +553,10 @@ class JdbcCodeBlocks { @@ -567,7 +553,10 @@ class JdbcCodeBlocks {
Assert.state(aotQuery != null, "AOT Query must not be null");
Builder builder = CodeBlock.builder();
return doBuild();
}
private CodeBlock doBuild() {
MethodReturn methodReturn = context.getMethodReturn();
boolean isProjecting = methodReturn.isProjecting()
@ -581,15 +570,167 @@ class JdbcCodeBlocks { @@ -581,15 +570,167 @@ class JdbcCodeBlocks {
String result = context.localVariable("result");
String rowMapper = context.localVariable("rowMapper");
ExecutionDecorator decorator = getExecutionDecorator();
if (modifying.isPresent()) {
return update(returnType);
return update(decorator, returnType);
} else if (aotQuery.isCount()) {
return count(result, returnType, queryResultType);
return count(decorator, result, queryResultType, returnType);
} else if (aotQuery.isExists()) {
return exists(queryResultType);
return exists(decorator, queryResultType);
} else if (aotQuery.isDelete()) {
return delete(rowMapper, result, queryResultType, returnType, actualReturnType);
return delete(decorator, rowMapper, result, queryResultType, returnType, actualReturnType);
} else {
return select(decorator, rowMapper, result, queryResultType, isProjecting, methodReturn);
}
}
private ExecutionDecorator getExecutionDecorator() {
if (aotQuery instanceof DerivedAotQuery) {
return new ExecutionDecorator() {
@Override
public String decorate(String executionCode) {
String builder = context.localVariable("builder");
return String.format("%s.executeWith((%s, %s) -> %s)", builder, query(), paramSource(), executionCode);
}
@Override
public String query() {
return context.localVariable("sql");
}
@Override
public String paramSource() {
return context.localVariable("paramSource");
}
};
}
return new ExecutionDecorator() {
@Override
public String decorate(String executionCode) {
return executionCode;
}
@Override
public String query() {
return queryVariableName;
}
@Override
public String paramSource() {
return parameterSourceVariableName;
}
};
}
/**
* Decorates an execution. Used to wrap execution with lambda or similar.
*/
interface ExecutionDecorator {
/**
* Decorate the given executionCode with a wrapper.
*
* @param executionCode
* @return
*/
String decorate(String executionCode);
String query();
String paramSource();
}
private CodeBlock update(ExecutionDecorator decorator, Class<?> returnType) {
String result = context.localVariable("result");
Builder builder = CodeBlock.builder();
LordOfTheStrings.InvocationBuilder invoke = LordOfTheStrings
.invoke(decorator.decorate("getJdbcOperations().update($L, $L)"), decorator.query(), decorator.paramSource());
if (context.getMethodReturn().isVoid()) {
builder.addStatement(invoke.build());
} else {
builder.addStatement(invoke.assignTo("int $L", result));
}
builder.addStatement(LordOfTheStrings.returning(returnType) //
.whenBoolean("$L != 0", result) //
.whenBoxedLong("(long) $L", result) //
.otherwise("$L", result)//
.build());
return builder.build();
}
private CodeBlock delete(ExecutionDecorator decorator, String rowMapper, String result, TypeName queryResultType,
Class<?> returnType, Type actualReturnType) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper,
context.getRepositoryInformation().getDomainType());
builder.addStatement(
"$1T $2L = ($1T) " + decorator.decorate("getJdbcOperations().query($3L, $4L, new $5T<>($6L))"), List.class,
result, decorator.query(), decorator.paramSource(), RowMapperResultSetExtractor.class, rowMapper);
builder.addStatement("$L.forEach(getOperations()::delete)", result);
builder.addStatement(LordOfTheStrings.returning(returnType) //
.when(Collection.class.isAssignableFrom(context.getMethodReturn().toClass()),
"($T) convertMany($L, $T.class)", context.getMethodReturn().getTypeName(), result, queryResultType) //
.when(context.getRepositoryInformation().getDomainType(),
"($1T) ($2L.isEmpty() ? null : $2L.iterator().next())", actualReturnType, result) //
.whenBoolean("!$L.isEmpty()", result) //
.whenBoxedLong("(long) $L.size()", result) //
.otherwise("$L.size()", result) //
.build());
return builder.build();
}
private CodeBlock count(ExecutionDecorator decorator, String result, TypeName queryResultType,
Class<?> returnType) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.addStatement("$1T $2L = " + decorator.decorate("queryForObject($3L, $4L, new $5T<>($1T.class))"),
Number.class, result, decorator.query(), decorator.paramSource(), SingleColumnRowMapper.class);
builder.addStatement(LordOfTheStrings.returning(returnType) //
.number(result) //
.whenPrimitiveOrBoxed(short.class, "$1L.shortValue()", result) //
.whenPrimitiveOrBoxed(byte.class, "$1L.byteValue()", result) //
.whenPrimitiveOrBoxed(double.class, "$1L.doubleValue()", result) //
.whenPrimitiveOrBoxed(float.class, "$1L.floatValue()", result) //
.otherwise("($T) convertOne($L, $T.class)", context.getMethodReturn().getTypeName(), result, queryResultType) //
.build());
return builder.build();
}
private CodeBlock exists(ExecutionDecorator decorator, TypeName queryResultType) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.addStatement("return ($T) " + decorator.decorate("getJdbcOperations().query($L, $L, $T::next)"),
queryResultType, decorator.query(), decorator.paramSource(), ResultSet.class);
return builder.build();
}
private CodeBlock select(ExecutionDecorator decorator, String rowMapper, String result, TypeName queryResultType,
boolean isProjecting, MethodReturn methodReturn) {
Builder builder = CodeBlock.builder();
String resultSetExtractor = null;
@ -608,8 +749,7 @@ class JdbcCodeBlocks { @@ -608,8 +749,7 @@ class JdbcCodeBlocks {
typeToRead = methodReturn.getActualReturnClass();
}
builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper,
typeToRead);
builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper, typeToRead);
}
if (StringUtils.hasText(resultSetExtractorRef) || resultSetExtractorClass != null) {
@ -630,8 +770,8 @@ class JdbcCodeBlocks { @@ -630,8 +770,8 @@ class JdbcCodeBlocks {
if (StringUtils.hasText(resultSetExtractor)) {
builder.addStatement("return ($T) getJdbcOperations().query($L, $L, $L)", queryResultType, queryVariableName,
parameterSourceVariableName, resultSetExtractor);
builder.addStatement("return ($T) " + decorator.decorate("getJdbcOperations().query($L, $L, $L)"),
queryResultType, decorator.query(), decorator.paramSource(), resultSetExtractor);
return builder.build();
}
@ -641,8 +781,9 @@ class JdbcCodeBlocks { @@ -641,8 +781,9 @@ class JdbcCodeBlocks {
if (queryMethod.isCollectionQuery() || queryMethod.isSliceQuery() || queryMethod.isPageQuery()) {
builder.addStatement("$1T $2L = ($1T) getJdbcOperations().query($3L, $4L, new $5T<>($6L))", List.class,
result, queryVariableName, parameterSourceVariableName, RowMapperResultSetExtractor.class, rowMapper);
builder.addStatement(
"$1T $2L = ($1T) " + decorator.decorate("getJdbcOperations().query($3L, $4L, new $5T<>($6L))"), List.class,
result, decorator.query(), decorator.paramSource(), RowMapperResultSetExtractor.class, rowMapper);
if (queryMethod.isSliceQuery() || queryMethod.isPageQuery()) {
@ -672,14 +813,14 @@ class JdbcCodeBlocks { @@ -672,14 +813,14 @@ class JdbcCodeBlocks {
methodReturn.getTypeName(), result, queryResultTypeRef);
} else if (queryMethod.isStreamQuery()) {
builder.addStatement("$1T $2L = getJdbcOperations().queryForStream($3L, $4L, $5L)", Stream.class, result,
queryVariableName, parameterSourceVariableName, rowMapper);
builder.addStatement("$1T $2L = " + decorator.decorate("getJdbcOperations().queryForStream($3L, $4L, $5L)"),
Stream.class, result, decorator.query(), decorator.paramSource(), rowMapper);
builder.addStatement("return ($T) convertMany($L, $T.class)", methodReturn.getTypeName(), result,
queryResultTypeRef);
} else {
builder.addStatement("$T $L = queryForObject($L, $L, $L)", Object.class, result, queryVariableName,
parameterSourceVariableName, rowMapper);
builder.addStatement("$T $L = " + decorator.decorate("queryForObject($L, $L, $L)"), Object.class, result,
decorator.query(), decorator.paramSource(), rowMapper);
if (methodReturn.isOptional()) {
builder.addStatement(
@ -690,83 +831,6 @@ class JdbcCodeBlocks { @@ -690,83 +831,6 @@ class JdbcCodeBlocks {
methodReturn.getTypeName(), result, queryResultTypeRef);
}
}
}
return builder.build();
}
private CodeBlock update(Class<?> returnType) {
String result = context.localVariable("result");
Builder builder = CodeBlock.builder();
LordOfTheStrings.InvocationBuilder invoke = LordOfTheStrings.invoke("getJdbcOperations().update($L, $L)",
queryVariableName, parameterSourceVariableName);
if (context.getMethodReturn().isVoid()) {
builder.addStatement(invoke.build());
} else {
builder.addStatement(invoke.assignTo("int $L", result));
}
builder.addStatement(LordOfTheStrings.returning(returnType) //
.whenBoolean("$L != 0", result) //
.whenBoxedLong("(long) $L", result) //
.otherwise("$L", result)//
.build());
return builder.build();
}
private CodeBlock delete(String rowMapper, String result, TypeName queryResultType,
Class<?> returnType, Type actualReturnType) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper,
context.getRepositoryInformation().getDomainType());
builder.addStatement("$1T $2L = ($1T) getJdbcOperations().query($3L, $4L, new $5T<>($6L))", List.class, result,
queryVariableName, parameterSourceVariableName, RowMapperResultSetExtractor.class, rowMapper);
builder.addStatement("$L.forEach(getOperations()::delete)", result);
builder.addStatement(LordOfTheStrings.returning(returnType) //
.when(Collection.class.isAssignableFrom(context.getMethodReturn().toClass()),
"($T) convertMany($L, $T.class)", context.getMethodReturn().getTypeName(), result, queryResultType) //
.when(context.getRepositoryInformation().getDomainType(),
"($1T) ($2L.isEmpty() ? null : $2L.iterator().next())", actualReturnType, result) //
.whenBoolean("!$L.isEmpty()", result) //
.whenBoxedLong("(long) $L.size()", result) //
.otherwise("$L.size()", result) //
.build());
return builder.build();
}
private CodeBlock count(String result, Class<?> returnType, TypeName queryResultType) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.addStatement("$1T $2L = queryForObject($3L, $4L, new $5T<>($1T.class))", Number.class, result,
queryVariableName, parameterSourceVariableName, SingleColumnRowMapper.class);
builder.addStatement(LordOfTheStrings.returning(returnType) //
.number(result) //
.otherwise("($T) convertOne($L, $T.class)", context.getMethodReturn().getTypeName(), result, queryResultType) //
.build());
return builder.build();
}
private CodeBlock exists(TypeName queryResultType) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.addStatement("return ($T) getJdbcOperations().query($L, $L, $T::next)", queryResultType,
queryVariableName, parameterSourceVariableName, ResultSet.class);
return builder.build();
}

45
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java

@ -42,8 +42,10 @@ import org.springframework.data.relational.core.sql.Select; @@ -42,8 +42,10 @@ import org.springframework.data.relational.core.sql.Select;
import org.springframework.data.relational.core.sql.SelectBuilder;
import org.springframework.data.relational.core.sql.Table;
import org.springframework.data.relational.core.sql.render.SqlRenderer;
import org.springframework.data.repository.query.ParametersSource;
import org.springframework.data.util.Predicates;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.lang.Contract;
/**
@ -58,12 +60,14 @@ public class StatementFactory { @@ -58,12 +60,14 @@ public class StatementFactory {
private final JdbcConverter converter;
private final RenderContextFactory renderContextFactory;
private final QueryMapper queryMapper;
private final Dialect dialect;
private final SqlGeneratorSource sqlGeneratorSource;
public StatementFactory(JdbcConverter converter, Dialect dialect) {
this.renderContextFactory = new RenderContextFactory(dialect);
this.converter = converter;
this.queryMapper = new QueryMapper(converter);
this.dialect = dialect;
this.sqlGeneratorSource = new SqlGeneratorSource(converter, dialect);
}
@ -170,6 +174,27 @@ public class StatementFactory { @@ -170,6 +174,27 @@ public class StatementFactory {
return this;
}
/**
* Build the SQL statement and apply the given function to the SQL string and its parameters.
*
* @param function SQL statement function accepting SQL string and parameters.
* @return the function result.
* @param <T> type of the function result.
*/
public <T extends @Nullable Object> T executeWith(StatementFunction<T> function) {
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
String sql = build(parameterSource);
return function.apply(sql, new EscapingParameterSource(parameterSource, dialect.getLikeEscaper()));
}
/**
* Build the SQL statement and assign parameters to the given {@link ParametersSource}.
*
* @param parameterSource the parameter source to be populated.
* @return the build SQL statement.
*/
public String build(MapSqlParameterSource parameterSource) {
SelectBuilder.SelectLimitOffset limitOffsetBuilder = createSelectClause(entity, table);
@ -260,5 +285,25 @@ public class StatementFactory { @@ -260,5 +285,25 @@ public class StatementFactory {
enum Mode {
COUNT, EXISTS, SELECT, SLICE
}
}
/**
* Represents a function that accepts a SQL string and a {@link ParametersSource} as arguments and produces a result.
* Ideal to run statements using {@link org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations} .
*/
@FunctionalInterface
public interface StatementFunction<T extends @Nullable Object> {
/**
* Applies this function to the given arguments.
*
* @param sql the SQL string.
* @param paramSource parameters for the SQL string.
* @return the function result.
*/
T apply(String sql, SqlParameterSource paramSource);
}
}

Loading…
Cancel
Save