diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcQueryCreator.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcQueryCreator.java index 342b6bf0c..d361dbbef 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcQueryCreator.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcQueryCreator.java @@ -15,18 +15,21 @@ */ package org.springframework.data.jdbc.repository.query; -import java.util.ArrayList; -import java.util.Collection; +import java.util.Collections; +import java.util.List; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.jdbc.core.convert.JdbcConverter; +import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.relational.core.dialect.Dialect; import org.springframework.data.relational.core.dialect.RenderContextFactory; +import org.springframework.data.relational.core.mapping.PersistentPropertyPathExtension; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.query.Criteria; +import org.springframework.data.relational.core.sql.Expression; import org.springframework.data.relational.core.sql.Select; import org.springframework.data.relational.core.sql.SelectBuilder; import org.springframework.data.relational.core.sql.SqlIdentifier; @@ -35,6 +38,8 @@ import org.springframework.data.relational.core.sql.render.SqlRenderer; import org.springframework.data.relational.repository.query.RelationalEntityMetadata; import org.springframework.data.relational.repository.query.RelationalParameterAccessor; import org.springframework.data.relational.repository.query.RelationalQueryCreator; +import org.springframework.data.repository.query.Parameters; +import org.springframework.data.repository.query.parser.Part; import org.springframework.data.repository.query.parser.PartTree; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.util.Assert; @@ -50,8 +55,6 @@ class JdbcQueryCreator extends RelationalQueryCreator { private final PartTree tree; private final RelationalParameterAccessor accessor; private final QueryMapper queryMapper; - - private final MappingContext, RelationalPersistentProperty> mappingContext; private final RelationalEntityMetadata entityMetadata; private final RenderContextFactory renderContextFactory; @@ -65,8 +68,8 @@ class JdbcQueryCreator extends RelationalQueryCreator { * @param entityMetadata relational entity metadata, must not be {@literal null}. * @param accessor parameter metadata provider, must not be {@literal null}. */ - JdbcQueryCreator(PartTree tree, JdbcConverter converter, Dialect dialect, - RelationalEntityMetadata entityMetadata, RelationalParameterAccessor accessor) { + JdbcQueryCreator(PartTree tree, JdbcConverter converter, Dialect dialect, RelationalEntityMetadata entityMetadata, + RelationalParameterAccessor accessor) { super(tree, accessor); Assert.notNull(converter, "JdbcConverter must not be null"); @@ -76,12 +79,60 @@ class JdbcQueryCreator extends RelationalQueryCreator { this.tree = tree; this.accessor = accessor; - this.mappingContext = (MappingContext) converter.getMappingContext(); this.entityMetadata = entityMetadata; this.queryMapper = new QueryMapper(dialect, converter); this.renderContextFactory = new RenderContextFactory(dialect); } + /** + * Validate parameters for the derived query. Specifically checking that the query method defines scalar parameters + * and collection parameters where required and that invalid parameter declarations are rejected. + * + * @param tree + * @param parameters + */ + public static void validate(PartTree tree, Parameters parameters, + MappingContext, ? extends RelationalPersistentProperty> context) { + + RelationalQueryCreator.validate(tree, parameters); + + for (PartTree.OrPart parts : tree) { + for (Part part : parts) { + + PersistentPropertyPath propertyPath = context + .getPersistentPropertyPath(part.getProperty()); + PersistentPropertyPathExtension path = new PersistentPropertyPathExtension(context, propertyPath); + + for (PersistentPropertyPathExtension pathToValidate = path; path.getLength() > 0; path = path.getParentPath()) { + validateProperty(pathToValidate); + } + } + } + } + + private static void validateProperty(PersistentPropertyPathExtension path) { + + if (!path.getParentPath().isEmbedded() && path.getLength() > 1) { + throw new IllegalArgumentException( + String.format("Cannot query by nested property: %s", path.getRequiredPersistentPropertyPath().toDotPath())); + } + + if (path.isMultiValued() || path.isMap()) { + throw new IllegalArgumentException(String.format("Cannot query by multi-valued property: %s", + path.getRequiredPersistentPropertyPath().getLeafProperty().getName())); + } + + if (!path.isEmbedded() && path.isEntity()) { + throw new IllegalArgumentException( + String.format("Cannot query by nested entity: %s", path.getRequiredPersistentPropertyPath().toDotPath())); + } + + if (path.getRequiredPersistentPropertyPath().getLeafProperty().isReference()) { + throw new IllegalArgumentException( + String.format("Cannot query by reference: %s", path.getRequiredPersistentPropertyPath().toDotPath())); + } + } + /** * Creates {@link ParametrizedQuery} applying the given {@link Criteria} and {@link Sort} definition. * @@ -96,7 +147,12 @@ class JdbcQueryCreator extends RelationalQueryCreator { Table table = Table.create(entityMetadata.getTableName()); MapSqlParameterSource parameterSource = new MapSqlParameterSource(); - SelectBuilder.SelectFromAndJoin builder = Select.builder().select(table.columns(getSelectProjection())).from(table); + List columns = table.columns(getSelectProjection()); + if (columns.isEmpty()) { + columns = Collections.singletonList(table.asterisk()); + } + + SelectBuilder.SelectFromAndJoin builder = Select.builder().select(columns).from(table); if (tree.isExistsProjection()) { builder = builder.limit(1); @@ -132,27 +188,6 @@ class JdbcQueryCreator extends RelationalQueryCreator { return new SqlIdentifier[] { tableEntity.getIdColumn() }; } - Collection columnNames = unwrapColumnNames("", tableEntity); - - return columnNames.toArray(new SqlIdentifier[0]); - } - - private Collection unwrapColumnNames(String prefix, RelationalPersistentEntity persistentEntity) { - - Collection columnNames = new ArrayList<>(); - - for (RelationalPersistentProperty property : persistentEntity) { - - if (property.isEmbedded()) { - columnNames.addAll( - unwrapColumnNames(prefix + property.getEmbeddedPrefix(), mappingContext.getPersistentEntity(property))); - } - - else { - columnNames.add(property.getColumnName().transform(prefix::concat)); - } - } - - return columnNames; + return new SqlIdentifier[0]; } } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java index 775080391..cc9ed78b8 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java @@ -63,15 +63,8 @@ public class PartTreeJdbcQuery extends AbstractJdbcQuery { this.dialect = dialect; this.converter = converter; - try { - - this.tree = new PartTree(queryMethod.getName(), queryMethod.getEntityInformation().getJavaType()); - JdbcQueryCreator.validate(this.tree, this.parameters); - } catch (RuntimeException e) { - - throw new IllegalArgumentException( - String.format("Failed to create query for method %s! %s", queryMethod, e.getMessage()), e); - } + this.tree = new PartTree(queryMethod.getName(), queryMethod.getEntityInformation().getJavaType()); + JdbcQueryCreator.validate(this.tree, this.parameters, this.converter.getMappingContext()); this.execution = getQueryExecution(queryMethod, null, rowMapper); } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/JdbcQueryLookupStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/JdbcQueryLookupStrategy.java index 2df4b4225..1c95e5d37 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/JdbcQueryLookupStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/JdbcQueryLookupStrategy.java @@ -35,6 +35,7 @@ import org.springframework.data.relational.core.mapping.event.AfterLoadCallback; import org.springframework.data.relational.core.mapping.event.AfterLoadEvent; import org.springframework.data.repository.core.NamedQueries; import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.repository.query.QueryCreationException; import org.springframework.data.repository.query.QueryLookupStrategy; import org.springframework.data.repository.query.RepositoryQuery; import org.springframework.jdbc.core.RowMapper; @@ -94,12 +95,16 @@ class JdbcQueryLookupStrategy implements QueryLookupStrategy { JdbcQueryMethod queryMethod = new JdbcQueryMethod(method, repositoryMetadata, projectionFactory, namedQueries, context); - if (namedQueries.hasQuery(queryMethod.getNamedQueryName()) || queryMethod.hasAnnotatedQuery()) { + try { + if (namedQueries.hasQuery(queryMethod.getNamedQueryName()) || queryMethod.hasAnnotatedQuery()) { - RowMapper mapper = queryMethod.isModifyingQuery() ? null : createMapper(queryMethod); - return new StringBasedJdbcQuery(queryMethod, operations, mapper, converter); - } else { - return new PartTreeJdbcQuery(queryMethod, dialect, converter, operations, createMapper(queryMethod)); + RowMapper mapper = queryMethod.isModifyingQuery() ? null : createMapper(queryMethod); + return new StringBasedJdbcQuery(queryMethod, operations, mapper, converter); + } else { + return new PartTreeJdbcQuery(queryMethod, dialect, converter, operations, createMapper(queryMethod)); + } + } catch (Exception e) { + throw QueryCreationException.create(queryMethod, e.getMessage()); } } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQueryUnitTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQueryUnitTests.java index 862d45d7f..bdd3849f6 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQueryUnitTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQueryUnitTests.java @@ -30,16 +30,20 @@ import java.util.Properties; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.data.annotation.Id; import org.springframework.data.jdbc.core.convert.BasicJdbcConverter; import org.springframework.data.jdbc.core.convert.JdbcConverter; import org.springframework.data.jdbc.core.convert.RelationResolver; +import org.springframework.data.jdbc.core.mapping.AggregateReference; import org.springframework.data.jdbc.core.mapping.JdbcMappingContext; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.relational.core.dialect.H2Dialect; import org.springframework.data.relational.core.mapping.Embedded; +import org.springframework.data.relational.core.mapping.MappedCollection; import org.springframework.data.relational.core.mapping.Table; import org.springframework.data.relational.repository.query.RelationalParametersParameterAccessor; +import org.springframework.data.repository.NoRepositoryBean; import org.springframework.data.repository.Repository; import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; import org.springframework.data.repository.core.support.PropertiesBasedNamedQueries; @@ -51,34 +55,50 @@ import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; * * @author Roman Chigvintsev * @author Mark Paluch + * @author Jens Schauder */ @RunWith(MockitoJUnitRunner.class) public class PartTreeJdbcQueryUnitTests { private static final String TABLE = "\"users\""; - private static final String ALL_FIELDS = "\"users\".\"ID\", \"users\".\"FIRST_NAME\", \"users\".\"LAST_NAME\", \"users\".\"DATE_OF_BIRTH\", \"users\".\"AGE\", \"users\".\"ACTIVE\", \"users\".\"USER_STREET\", \"users\".\"USER_CITY\""; + private static final String ALL_FIELDS = "\"users\".*"; JdbcMappingContext mappingContext = new JdbcMappingContext(); JdbcConverter converter = new BasicJdbcConverter(mappingContext, mock(RelationResolver.class)); @Test // DATAJDBC-318 - public void selectContainsColumnsForOneToOneReference() throws Exception { + public void shouldFailForQueryByReference() throws Exception { - JdbcQueryMethod queryMethod = getQueryMethod("findAllByFirstName", String.class); - PartTreeJdbcQuery jdbcQuery = createQuery(queryMethod); - ParametrizedQuery query = jdbcQuery.createQuery(getAccessor(queryMethod, new Object[] { "John" })); + JdbcQueryMethod queryMethod = getQueryMethod("findAllByHated", Hobby.class); + assertThatIllegalArgumentException().isThrownBy(() -> createQuery(queryMethod)); + } + + @Test // DATAJDBC-318 + public void shouldFailForQueryByAggregateReference() throws Exception { - assertThat(query.getQuery()).contains("hated.\"name\" AS \"hated_name\""); + JdbcQueryMethod queryMethod = getQueryMethod("findAllByHobbyReference", Hobby.class); + assertThatIllegalArgumentException().isThrownBy(() -> createQuery(queryMethod)); } @Test // DATAJDBC-318 - public void doesNotContainsColumnsForOneToManyReference() throws Exception{ + public void shouldFailForQueryByList() throws Exception { - JdbcQueryMethod queryMethod = getQueryMethod("findAllByFirstName", String.class); - PartTreeJdbcQuery jdbcQuery = createQuery(queryMethod); - ParametrizedQuery query = jdbcQuery.createQuery(getAccessor(queryMethod, new Object[] { "John" })); + JdbcQueryMethod queryMethod = getQueryMethod("findAllByHobbies", Object.class); + assertThatIllegalArgumentException().isThrownBy(() -> createQuery(queryMethod)); + } + + @Test // DATAJDBC-318 + public void shouldFailForQueryByEmbeddedList() throws Exception { - assertThat(query.getQuery().toLowerCase()).doesNotContain("hobbies"); + JdbcQueryMethod queryMethod = getQueryMethod("findByAnotherEmbeddedList", Object.class); + assertThatIllegalArgumentException().isThrownBy(() -> createQuery(queryMethod)); + } + + @Test // DATAJDBC-318 + public void shouldFailForAggregateReference() throws Exception { + + JdbcQueryMethod queryMethod = getQueryMethod("findByAnotherEmbeddedList", Object.class); + assertThatIllegalArgumentException().isThrownBy(() -> createQuery(queryMethod)); } @Test // DATAJDBC-318 @@ -570,10 +590,17 @@ public class PartTreeJdbcQueryUnitTests { return new RelationalParametersParameterAccessor(queryMethod, values); } + @NoRepositoryBean interface UserRepository extends Repository { List findAllByFirstName(String firstName); + List findAllByHated(Hobby hobby); + + List findAllByHobbies(Object hobbies); + + List findAllByHobbyReference(Hobby hobby); + List findAllByLastNameAndFirstName(String lastName, String firstName); List findAllByLastNameOrFirstName(String lastName, String firstName); @@ -637,6 +664,8 @@ public class PartTreeJdbcQueryUnitTests { User findByAddress(Address address); User findByAddressStreet(String street); + + User findByAnotherEmbeddedList(Object list); } @Table("users") @@ -650,9 +679,12 @@ public class PartTreeJdbcQueryUnitTests { Boolean active; @Embedded(prefix = "user_", onEmpty = Embedded.OnEmpty.USE_NULL) Address address; + @Embedded.Nullable AnotherEmbedded anotherEmbedded; List hobbies; Hobby hated; + + AggregateReference hobbyReference; } @AllArgsConstructor @@ -661,6 +693,11 @@ public class PartTreeJdbcQueryUnitTests { String city; } + @AllArgsConstructor + static class AnotherEmbedded { + @MappedCollection(idColumn = "ID", keyColumn = "ORDER_KEY") List list; + } + static class Hobby { String name; }