From 03a667ca6d509d975f23021ecdeceeb65ed375dc Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 7 Oct 2020 14:50:54 +0200 Subject: [PATCH] #465 - Convert bind values for String-based queries to their native type. We now invoke the converter for query arguments of String-based queries (using the Query annotation or named queries) to convert the value into a type that can be used by the driver. Previously all values were passed-thru which caused for e.g. enums to pass-thru these to the driver and the driver encode failed. --- .../ExpressionEvaluatingParameterBinder.java | 84 +++++++++++++------ .../query/StringBasedR2dbcQuery.java | 11 ++- .../support/R2dbcRepositoryFactory.java | 4 +- .../query/StringBasedR2dbcQueryUnitTests.java | 27 +++++- 4 files changed, 96 insertions(+), 30 deletions(-) diff --git a/src/main/java/org/springframework/data/r2dbc/repository/query/ExpressionEvaluatingParameterBinder.java b/src/main/java/org/springframework/data/r2dbc/repository/query/ExpressionEvaluatingParameterBinder.java index 66eb6a80e..a218379f7 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/query/ExpressionEvaluatingParameterBinder.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/query/ExpressionEvaluatingParameterBinder.java @@ -22,6 +22,7 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Pattern; +import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; import org.springframework.data.relational.repository.query.RelationalParameterAccessor; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; @@ -38,15 +39,19 @@ class ExpressionEvaluatingParameterBinder { private final ExpressionQuery expressionQuery; + private final ReactiveDataAccessStrategy dataAccessStrategy; + private final Map namedParameters = new ConcurrentHashMap<>(); /** * Creates new {@link ExpressionEvaluatingParameterBinder} * * @param expressionQuery must not be {@literal null}. + * @param dataAccessStrategy must not be {@literal null}. */ - ExpressionEvaluatingParameterBinder(ExpressionQuery expressionQuery) { + ExpressionEvaluatingParameterBinder(ExpressionQuery expressionQuery, ReactiveDataAccessStrategy dataAccessStrategy) { this.expressionQuery = expressionQuery; + this.dataAccessStrategy = dataAccessStrategy; } /** @@ -76,50 +81,44 @@ class ExpressionEvaluatingParameterBinder { for (ParameterBinding binding : expressionQuery.getBindings()) { - org.springframework.r2dbc.core.Parameter valueForBinding = evaluator.evaluate(binding.getExpression()); + org.springframework.r2dbc.core.Parameter valueForBinding = getBindValue( + evaluator.evaluate(binding.getExpression())); - if (valueForBinding.isEmpty()) { - bindSpecToUse = bindSpecToUse.bindNull(binding.getParameterName(), valueForBinding.getType()); - } else { - bindSpecToUse = bindSpecToUse.bind(binding.getParameterName(), valueForBinding.getValue()); - } + bindSpecToUse = bind(bindSpecToUse, binding.getParameterName(), valueForBinding); } return bindSpecToUse; } private DatabaseClient.GenericExecuteSpec bindParameters(DatabaseClient.GenericExecuteSpec bindSpec, - boolean bindableNull, Object[] values, Parameters bindableParameters) { + boolean hasBindableNullValue, Object[] values, Parameters bindableParameters) { DatabaseClient.GenericExecuteSpec bindSpecToUse = bindSpec; int bindingIndex = 0; for (Parameter bindableParameter : bindableParameters) { - Object value = values[bindableParameter.getIndex()]; Optional name = bindableParameter.getName(); - if ((name.isPresent() && isNamedParameterUsed(name)) || !expressionQuery.getBindings().isEmpty()) { + if (name.isPresent() && (isNamedParameterReferencedFromQuery(name)) || !expressionQuery.getBindings().isEmpty()) { - if (isNamedParameterUsed(name)) { + if (!isNamedParameterReferencedFromQuery(name)) { + continue; + } + + org.springframework.r2dbc.core.Parameter parameter = getBindValue(values, bindableParameter); - if (value == null) { - if (bindableNull) { - bindSpecToUse = bindSpecToUse.bindNull(name.get(), bindableParameter.getType()); - } - } else { - bindSpecToUse = bindSpecToUse.bind(name.get(), value); - } + if (!parameter.isEmpty() || hasBindableNullValue) { + bindSpecToUse = bind(bindSpecToUse, name.get(), parameter); } // skip unused named parameters if there is SpEL } else { - if (value == null) { - if (bindableNull) { - bindSpecToUse = bindSpecToUse.bindNull(bindingIndex++, bindableParameter.getType()); - } - } else { - bindSpecToUse = bindSpecToUse.bind(bindingIndex++, value); + + org.springframework.r2dbc.core.Parameter parameter = getBindValue(values, bindableParameter); + + if (!parameter.isEmpty() || hasBindableNullValue) { + bindSpecToUse = bind(bindSpecToUse, bindingIndex++, parameter); } } } @@ -127,7 +126,42 @@ class ExpressionEvaluatingParameterBinder { return bindSpecToUse; } - private boolean isNamedParameterUsed(Optional name) { + private org.springframework.r2dbc.core.Parameter getBindValue(Object[] values, Parameter bindableParameter) { + + org.springframework.r2dbc.core.Parameter parameter = org.springframework.r2dbc.core.Parameter + .fromOrEmpty(values[bindableParameter.getIndex()], bindableParameter.getType()); + + return dataAccessStrategy.getBindValue(parameter); + } + + private static DatabaseClient.GenericExecuteSpec bind(DatabaseClient.GenericExecuteSpec spec, String name, + org.springframework.r2dbc.core.Parameter parameter) { + + Object value = parameter.getValue(); + if (value == null) { + return spec.bindNull(name, parameter.getType()); + } else { + return spec.bind(name, value); + } + } + + private static DatabaseClient.GenericExecuteSpec bind(DatabaseClient.GenericExecuteSpec spec, int index, + org.springframework.r2dbc.core.Parameter parameter) { + + Object value = parameter.getValue(); + if (value == null) { + return spec.bindNull(index, parameter.getType()); + } else { + + return spec.bind(index, value); + } + } + + private org.springframework.r2dbc.core.Parameter getBindValue(org.springframework.r2dbc.core.Parameter bindValue) { + return dataAccessStrategy.getBindValue(bindValue); + } + + private boolean isNamedParameterReferencedFromQuery(Optional name) { if (!name.isPresent()) { return false; diff --git a/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java b/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java index 2cd55e4b3..81ae7ed9a 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; import org.springframework.data.r2dbc.convert.R2dbcConverter; +import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; import org.springframework.data.r2dbc.repository.Query; import org.springframework.data.relational.repository.query.RelationalParameterAccessor; import org.springframework.data.repository.query.QueryMethodEvaluationContextProvider; @@ -54,12 +55,15 @@ public class StringBasedR2dbcQuery extends AbstractR2dbcQuery { * @param queryMethod must not be {@literal null}. * @param databaseClient must not be {@literal null}. * @param converter must not be {@literal null}. + * @param dataAccessStrategy must not be {@literal null}. * @param expressionParser must not be {@literal null}. * @param evaluationContextProvider must not be {@literal null}. */ public StringBasedR2dbcQuery(R2dbcQueryMethod queryMethod, DatabaseClient databaseClient, R2dbcConverter converter, + ReactiveDataAccessStrategy dataAccessStrategy, ExpressionParser expressionParser, ReactiveQueryMethodEvaluationContextProvider evaluationContextProvider) { - this(queryMethod.getRequiredAnnotatedQuery(), queryMethod, databaseClient, converter, expressionParser, + this(queryMethod.getRequiredAnnotatedQuery(), queryMethod, databaseClient, converter, dataAccessStrategy, + expressionParser, evaluationContextProvider); } @@ -70,11 +74,12 @@ public class StringBasedR2dbcQuery extends AbstractR2dbcQuery { * @param method must not be {@literal null}. * @param databaseClient must not be {@literal null}. * @param converter must not be {@literal null}. + * @param dataAccessStrategy must not be {@literal null}. * @param expressionParser must not be {@literal null}. * @param evaluationContextProvider must not be {@literal null}. */ public StringBasedR2dbcQuery(String query, R2dbcQueryMethod method, DatabaseClient databaseClient, - R2dbcConverter converter, ExpressionParser expressionParser, + R2dbcConverter converter, ReactiveDataAccessStrategy dataAccessStrategy, ExpressionParser expressionParser, ReactiveQueryMethodEvaluationContextProvider evaluationContextProvider) { super(method, databaseClient, converter); @@ -84,7 +89,7 @@ public class StringBasedR2dbcQuery extends AbstractR2dbcQuery { Assert.hasText(query, "Query must not be empty"); this.expressionQuery = ExpressionQuery.create(query); - this.binder = new ExpressionEvaluatingParameterBinder(expressionQuery); + this.binder = new ExpressionEvaluatingParameterBinder(expressionQuery, dataAccessStrategy); this.expressionDependencies = createExpressionDependencies(); } diff --git a/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java b/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java index f3bbb00bc..70fb65713 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java @@ -189,9 +189,11 @@ public class R2dbcRepositoryFactory extends ReactiveRepositoryFactorySupport { if (namedQueries.hasQuery(namedQueryName)) { String namedQuery = namedQueries.getQuery(namedQueryName); return new StringBasedR2dbcQuery(namedQuery, queryMethod, this.databaseClient, this.converter, + this.dataAccessStrategy, parser, this.evaluationContextProvider); } else if (queryMethod.hasAnnotatedQuery()) { - return new StringBasedR2dbcQuery(queryMethod, this.databaseClient, this.converter, parser, + return new StringBasedR2dbcQuery(queryMethod, this.databaseClient, this.converter, this.dataAccessStrategy, + this.parser, this.evaluationContextProvider); } else { return new PartTreeR2dbcQuery(queryMethod, this.databaseClient, this.converter, this.dataAccessStrategy); diff --git a/src/test/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQueryUnitTests.java b/src/test/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQueryUnitTests.java index 87c4ce536..a8f43c2f8 100644 --- a/src/test/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQueryUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQueryUnitTests.java @@ -33,6 +33,9 @@ import org.springframework.data.domain.Sort; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.r2dbc.convert.MappingR2dbcConverter; +import org.springframework.data.r2dbc.core.DefaultReactiveDataAccessStrategy; +import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; +import org.springframework.data.r2dbc.dialect.PostgresDialect; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; import org.springframework.data.r2dbc.repository.Query; import org.springframework.data.relational.core.mapping.RelationalMappingContext; @@ -62,6 +65,7 @@ public class StringBasedR2dbcQueryUnitTests { private RelationalMappingContext mappingContext; private MappingR2dbcConverter converter; + private ReactiveDataAccessStrategy accessStrategy; private ProjectionFactory factory; private RepositoryMetadata metadata; @@ -70,6 +74,7 @@ public class StringBasedR2dbcQueryUnitTests { this.mappingContext = new R2dbcMappingContext(); this.converter = new MappingR2dbcConverter(this.mappingContext); + this.accessStrategy = new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE, converter); this.metadata = AbstractRepositoryMetadata.getMetadata(SampleRepository.class); this.factory = new SpelAwareProxyProjectionFactory(); @@ -240,13 +245,26 @@ public class StringBasedR2dbcQueryUnitTests { verifyNoMoreInteractions(bindSpec); } + @Test // gh-465 + void translatesEnumToDatabaseValue() { + + StringBasedR2dbcQuery query = getQueryMethod("queryWithEnum", MyEnum.class); + R2dbcParameterAccessor accessor = new R2dbcParameterAccessor(query.getQueryMethod(), MyEnum.INSTANCE); + + BindableQuery stringQuery = query.createQuery(accessor).block(); + assertThat(stringQuery.bind(bindSpec)).isNotNull(); + + verify(bindSpec).bind(0, "INSTANCE"); + verifyNoMoreInteractions(bindSpec); + } + private StringBasedR2dbcQuery getQueryMethod(String name, Class... args) { Method method = ReflectionUtils.findMethod(SampleRepository.class, name, args); R2dbcQueryMethod queryMethod = new R2dbcQueryMethod(method, metadata, factory, converter.getMappingContext()); - return new StringBasedR2dbcQuery(queryMethod, databaseClient, converter, PARSER, + return new StringBasedR2dbcQuery(queryMethod, databaseClient, converter, accessStrategy, PARSER, ReactiveQueryMethodEvaluationContextProvider.DEFAULT); } @@ -285,6 +303,9 @@ public class StringBasedR2dbcQueryUnitTests { @Query("SELECT * FROM person WHERE lastname = :name") Person queryWithUnusedParameter(String name, Sort unused); + + @Query("SELECT * FROM person WHERE lastname = :name") + Person queryWithEnum(MyEnum myEnum); } static class Person { @@ -299,4 +320,8 @@ public class StringBasedR2dbcQueryUnitTests { return name; } } + + enum MyEnum { + INSTANCE; + } }