From b46e34263eed96462cf3c0db93cf7ab2132610a5 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Thu, 18 Aug 2022 12:06:03 +0200 Subject: [PATCH] Fix COUNT/EXISTS projections for entities without an identifier. We now issue a COUNT(1) respective SELECT 1 for COUNT queries and EXISTS queries for entities that do not specify an identifier. Previously these query projections could fail because of empty select lists. Closes #1310 --- .../data/r2dbc/core/R2dbcEntityTemplate.java | 15 +++--- .../data/r2dbc/query/QueryMapper.java | 3 +- .../repository/query/R2dbcQueryCreator.java | 21 ++++---- .../core/R2dbcEntityTemplateUnitTests.java | 30 +++++++++++ .../query/PartTreeR2dbcQueryUnitTests.java | 52 +++++++++++++++++-- 5 files changed, 100 insertions(+), 21 deletions(-) diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java index d327042ca..7379f4eda 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java @@ -298,7 +298,7 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw Expression countExpression = entity.hasIdProperty() ? table.column(entity.getRequiredIdProperty().getColumnName()) - : Expressions.asterisk(); + : Expressions.just("1"); return spec.withProjection(Functions.count(countExpression)); }); @@ -333,13 +333,14 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw RelationalPersistentEntity entity = getRequiredEntity(entityClass); StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityClass); - SqlIdentifier columnName = entity.hasIdProperty() ? entity.getRequiredIdProperty().getColumnName() - : SqlIdentifier.unquoted("*"); + StatementMapper.SelectSpec selectSpec = statementMapper.createSelect(tableName).limit(1); + if (entity.hasIdProperty()) { + selectSpec = selectSpec // + .withProjection(entity.getRequiredIdProperty().getColumnName()); - StatementMapper.SelectSpec selectSpec = statementMapper // - .createSelect(tableName) // - .withProjection(columnName) // - .limit(1); + } else { + selectSpec = selectSpec.withProjection(Expressions.just("1")); + } Optional criteria = query.getCriteria(); if (criteria.isPresent()) { diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java index 188760ed7..0bf28c9d0 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java @@ -153,7 +153,8 @@ public class QueryMapper { */ public Expression getMappedObject(Expression expression, @Nullable RelationalPersistentEntity entity) { - if (entity == null || expression instanceof AsteriskFromTable) { + if (entity == null || expression instanceof AsteriskFromTable + || expression instanceof Expressions.SimpleExpression) { return expression; } diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java index 89d2dfa66..4cdcf7a73 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java @@ -25,11 +25,16 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; import org.springframework.data.r2dbc.core.StatementMapper; -import org.springframework.data.relational.repository.Lock; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.query.Criteria; -import org.springframework.data.relational.core.sql.*; +import org.springframework.data.relational.core.sql.Column; +import org.springframework.data.relational.core.sql.Expression; +import org.springframework.data.relational.core.sql.Expressions; +import org.springframework.data.relational.core.sql.Functions; +import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.data.relational.core.sql.Table; +import org.springframework.data.relational.repository.Lock; import org.springframework.data.relational.repository.query.RelationalEntityMetadata; import org.springframework.data.relational.repository.query.RelationalParameterAccessor; import org.springframework.data.relational.repository.query.RelationalQueryCreator; @@ -164,18 +169,14 @@ class R2dbcQueryCreator extends RelationalQueryCreator> { expressions.add(column); } - } else if (tree.isExistsProjection()) { - - expressions = dataAccessStrategy.getIdentifierColumns(entityToRead).stream() // - .map(table::column) // - .collect(Collectors.toList()); - } else if (tree.isCountProjection()) { + } else if (tree.isExistsProjection() || tree.isCountProjection()) { Expression countExpression = entityMetadata.getTableEntity().hasIdProperty() ? table.column(entityMetadata.getTableEntity().getRequiredIdProperty().getColumnName()) - : Expressions.asterisk(); + : Expressions.just("1"); - expressions = Collections.singletonList(Functions.count(countExpression)); + expressions = Collections + .singletonList(tree.isCountProjection() ? Functions.count(countExpression) : countExpression); } else { expressions = dataAccessStrategy.getAllColumns(entityToRead).stream() // .map(table::column) // diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java index 10be58109..b9c0a12a0 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java @@ -122,6 +122,30 @@ public class R2dbcEntityTemplateUnitTests { .verifyComplete(); } + @Test // gh-1310 + void shouldProjectExistsResultWithoutId() { + + MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Object.class, null).build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT 1"), result); + + entityTemplate.select(WithoutId.class).exists() // + .as(StepVerifier::create) // + .expectNext(true).verifyComplete(); + } + + @Test // gh-1310 + void shouldProjectCountResultWithoutId() { + + MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT COUNT(1)"), result); + + entityTemplate.select(WithoutId.class).count() // + .as(StepVerifier::create) // + .expectNext(1L).verifyComplete(); + } + @Test // gh-469 void shouldExistsByCriteria() { @@ -477,6 +501,12 @@ public class R2dbcEntityTemplateUnitTests { Parameter.from("before-save")); } + @Value + static class WithoutId { + + String name; + } + @Value @With static class Person { diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java index 6a73c1b95..f25c42914 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java @@ -38,6 +38,7 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; + import org.springframework.beans.factory.annotation.Value; import org.springframework.data.annotation.Id; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; @@ -49,10 +50,10 @@ import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; import org.springframework.data.r2dbc.dialect.DialectResolver; import org.springframework.data.r2dbc.dialect.R2dbcDialect; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; -import org.springframework.data.relational.repository.Lock; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.Table; import org.springframework.data.relational.core.sql.LockMode; +import org.springframework.data.relational.repository.Lock; import org.springframework.data.relational.repository.query.RelationalParametersParameterAccessor; import org.springframework.data.repository.Repository; import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; @@ -748,6 +749,32 @@ class PartTreeR2dbcQueryUnitTests { verify(bindTarget, times(1)).bind(0, "John"); } + @Test // GH-1310 + void createsQueryWithoutIdForCountProjection() throws Exception { + + R2dbcQueryMethod queryMethod = getQueryMethod(WithoutIdRepository.class, "countByFirstName", String.class); + PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy); + PreparedOperation query = createQuery(queryMethod, r2dbcQuery, "John"); + + PreparedOperationAssert.assertThat(query) // + .selects("COUNT(1)") // + .from(TABLE) // + .where(TABLE + ".first_name = $1"); + } + + @Test // GH-1310 + void createsQueryWithoutIdForExistsProjection() throws Exception { + + R2dbcQueryMethod queryMethod = getQueryMethod(WithoutIdRepository.class, "existsByFirstName", String.class); + PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy); + PreparedOperation query = createQuery(queryMethod, r2dbcQuery, "John"); + + PreparedOperationAssert.assertThat(query) // + .selects("1") // + .from(TABLE) // + .where(TABLE + ".first_name = $1 LIMIT 1"); + } + private PreparedOperation createQuery(R2dbcQueryMethod queryMethod, PartTreeR2dbcQuery r2dbcQuery, Object... parameters) { return createQuery(r2dbcQuery, getAccessor(queryMethod, parameters)); @@ -759,8 +786,13 @@ class PartTreeR2dbcQueryUnitTests { } private R2dbcQueryMethod getQueryMethod(String methodName, Class... parameterTypes) throws Exception { - Method method = UserRepository.class.getMethod(methodName, parameterTypes); - return new R2dbcQueryMethod(method, new DefaultRepositoryMetadata(UserRepository.class), + return getQueryMethod(UserRepository.class, methodName, parameterTypes); + } + + private R2dbcQueryMethod getQueryMethod(Class repository, String methodName, Class... parameterTypes) + throws Exception { + Method method = repository.getMethod(methodName, parameterTypes); + return new R2dbcQueryMethod(method, new DefaultRepositoryMetadata(repository), new SpelAwareProxyProjectionFactory(), mappingContext); } @@ -946,6 +978,13 @@ class PartTreeR2dbcQueryUnitTests { } + interface WithoutIdRepository extends Repository { + + Mono existsByFirstName(String firstName); + + Mono countByFirstName(String firstName); + } + @Table("users") @Data private static class User { @@ -958,6 +997,13 @@ class PartTreeR2dbcQueryUnitTests { private Boolean active; } + @Table("users") + @Data + private static class WithoutId { + + private String firstName; + } + interface UserProjection { String getFirstName();