Browse Source

Allow non-domain sort orders to be used with `R2dbcQueryCreator`.

Closes #1548
pull/1570/head
Mark Paluch 3 years ago
parent
commit
e2514a3454
No known key found for this signature in database
GPG Key ID: 4406B84C1661DCD1
  1. 13
      spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java
  2. 21
      spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java

13
spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java

@ -136,7 +136,7 @@ class R2dbcQueryCreator extends RelationalQueryCreator<PreparedOperation<?>> {
} }
if (sort.isSorted()) { if (sort.isSorted()) {
selectSpec = selectSpec.withSort(getSort(sort)); selectSpec = selectSpec.withSort(sort);
} }
if (tree.isDistinct()) { if (tree.isDistinct()) {
@ -186,15 +186,4 @@ class R2dbcQueryCreator extends RelationalQueryCreator<PreparedOperation<?>> {
return expressions.toArray(new Expression[0]); return expressions.toArray(new Expression[0]);
} }
private Sort getSort(Sort sort) {
RelationalPersistentEntity<?> tableEntity = entityMetadata.getTableEntity();
List<Sort.Order> orders = sort.get().map(order -> {
RelationalPersistentProperty property = tableEntity.getRequiredPersistentProperty(order.getProperty());
return order.isAscending() ? Sort.Order.asc(property.getName()) : Sort.Order.desc(property.getName());
}).collect(Collectors.toList());
return Sort.by(orders);
}
} }

21
spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java

@ -38,9 +38,10 @@ import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness; import org.mockito.quality.Strictness;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Id;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Sort.Direction;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.r2dbc.convert.R2dbcConverter; import org.springframework.data.r2dbc.convert.R2dbcConverter;
import org.springframework.data.r2dbc.core.DefaultReactiveDataAccessStrategy; import org.springframework.data.r2dbc.core.DefaultReactiveDataAccessStrategy;
@ -53,6 +54,7 @@ import org.springframework.data.r2dbc.mapping.R2dbcMappingContext;
import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.mapping.Table; import org.springframework.data.relational.core.mapping.Table;
import org.springframework.data.relational.core.sql.LockMode; import org.springframework.data.relational.core.sql.LockMode;
import org.springframework.data.relational.domain.SqlSort;
import org.springframework.data.relational.repository.Lock; import org.springframework.data.relational.repository.Lock;
import org.springframework.data.relational.repository.query.RelationalParametersParameterAccessor; import org.springframework.data.relational.repository.query.RelationalParametersParameterAccessor;
import org.springframework.data.repository.Repository; import org.springframework.data.repository.Repository;
@ -599,6 +601,21 @@ class PartTreeR2dbcQueryUnitTests {
.isThrownBy(() -> createQuery(r2dbcQuery, getAccessor(queryMethod, new Object[0]))); .isThrownBy(() -> createQuery(r2dbcQuery, getAccessor(queryMethod, new Object[0])));
} }
@Test // GH-1548
void allowsSortingByNonDomainProperties() throws Exception {
R2dbcQueryMethod queryMethod = getQueryMethod("findAllByFirstName", String.class, Sort.class);
PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy);
PreparedOperation<?> preparedOperation = createQuery(queryMethod, r2dbcQuery, "foo", Sort.by("foobar"));
PreparedOperationAssert.assertThat(preparedOperation) //
.orderBy("users.foobar ASC");
preparedOperation = createQuery(queryMethod, r2dbcQuery, "foo", SqlSort.unsafe(Direction.ASC, "sum(foobar)"));
PreparedOperationAssert.assertThat(preparedOperation) //
.orderBy("sum(foobar) ASC");
}
@Test // GH-282 @Test // GH-282
void throwsExceptionWhenInvalidNumberOfParameterIsGiven() throws Exception { void throwsExceptionWhenInvalidNumberOfParameterIsGiven() throws Exception {
@ -960,6 +977,8 @@ class PartTreeR2dbcQueryUnitTests {
Flux<User> findAllByIdIsEmpty(); Flux<User> findAllByIdIsEmpty();
Flux<User> findAllByFirstName(String firstName, Sort sort);
Flux<User> findTop3ByFirstName(String firstName); Flux<User> findTop3ByFirstName(String firstName);
Mono<User> findFirstByFirstName(String firstName); Mono<User> findFirstByFirstName(String firstName);

Loading…
Cancel
Save