From c241cd982ee73fff32bc13bac6fda9931d83faaa Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 26 Feb 2019 13:50:20 +0100 Subject: [PATCH] #56 - Use Statement Builder API for SELECT statements. Original pull request: #66. --- .../r2dbc/function/DefaultDatabaseClient.java | 17 +-- .../DefaultReactiveDataAccessStrategy.java | 106 ++++++------------ .../function/ReactiveDataAccessStrategy.java | 37 +----- .../support/SimpleR2dbcRepository.java | 70 +++++++++--- .../r2dbc/support/StatementRenderUtil.java | 65 +++++++++++ ...ltReactiveDataAccessStrategyUnitTests.java | 42 ------- 6 files changed, 154 insertions(+), 183 deletions(-) create mode 100644 src/main/java/org/springframework/data/r2dbc/support/StatementRenderUtil.java diff --git a/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java index b8587e06a..6d0ed6f9e 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java @@ -703,17 +703,9 @@ class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor { private FetchSpec exchange(BiFunction mappingFunction) { - Set columns; + String select = dataAccessStrategy.select(table, new LinkedHashSet<>(this.projectedFields), sort, page); - if (this.projectedFields.isEmpty()) { - columns = Collections.singleton("*"); - } else { - columns = new LinkedHashSet<>(this.projectedFields); - } - - QueryOperation select = dataAccessStrategy.select(table, columns, sort, page); - - return execute(select.toQuery(), mappingFunction); + return execute(select, mappingFunction); } @Override @@ -797,11 +789,10 @@ class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor { } else { columns = this.projectedFields; } - Sort sortToUse = sort.isSorted() ? dataAccessStrategy.getMappedSort(typeToRead, sort) : Sort.unsorted(); - QueryOperation select = dataAccessStrategy.select(table, new LinkedHashSet<>(columns), sortToUse, page); + String select = dataAccessStrategy.select(table, new LinkedHashSet<>(columns), sort, page); - return execute(select.get(), mappingFunction); + return execute(select, mappingFunction); } @Override diff --git a/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java b/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java index 54740652a..2b5f94dae 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java +++ b/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.OptionalLong; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Function; @@ -43,16 +44,20 @@ import org.springframework.data.r2dbc.dialect.BindMarker; import org.springframework.data.r2dbc.dialect.BindMarkers; import org.springframework.data.r2dbc.dialect.BindMarkersFactory; import org.springframework.data.r2dbc.dialect.Dialect; -import org.springframework.data.r2dbc.dialect.LimitClause; -import org.springframework.data.r2dbc.dialect.LimitClause.Position; import org.springframework.data.r2dbc.function.convert.EntityRowMapper; import org.springframework.data.r2dbc.function.convert.R2dbcCustomConversions; import org.springframework.data.r2dbc.function.convert.SettableValue; +import org.springframework.data.r2dbc.support.StatementRenderUtil; import org.springframework.data.relational.core.conversion.BasicRelationalConverter; import org.springframework.data.relational.core.conversion.RelationalConverter; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; +import org.springframework.data.relational.core.sql.Expression; +import org.springframework.data.relational.core.sql.OrderByField; +import org.springframework.data.relational.core.sql.SelectBuilder.SelectFromAndOrderBy; +import org.springframework.data.relational.core.sql.StatementBuilder; +import org.springframework.data.relational.core.sql.Table; import org.springframework.data.util.TypeInformation; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -312,94 +317,47 @@ public class DefaultReactiveDataAccessStrategy implements ReactiveDataAccessStra * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#select(java.lang.String, java.util.Set, org.springframework.data.domain.Sort, org.springframework.data.domain.Pageable) */ @Override - public QueryOperation select(String table, Set columns, Sort sort, Pageable page) { + public String select(String table, Set columns, Sort sort, Pageable page) { - StringBuilder selectBuilder = new StringBuilder(); + Table tableToUse = Table.create(table); - selectBuilder.append("SELECT").append(' ') // - .append(StringUtils.collectionToDelimitedString(columns, ", ")).append(' ') // - .append("FROM").append(' ').append(table); + Collection selectList; - if (sort.isSorted()) { - selectBuilder.append(' ').append("ORDER BY").append(' ').append(getSortClause(sort)); + if (columns.isEmpty()) { + selectList = Collections.singletonList(tableToUse.asterisk()); + } else { + selectList = tableToUse.columns(columns); } - if (page.isPaged()) { - - LimitClause limitClause = dialect.limit(); - - if (limitClause.getClausePosition() == Position.END) { - - selectBuilder.append(' ').append(limitClause.getClause(page.getPageSize(), page.getOffset())); - } - } - - return selectBuilder::toString; - } - - private StringBuilder getSortClause(Sort sort) { - - StringBuilder sortClause = new StringBuilder(); + SelectFromAndOrderBy selectBuilder = StatementBuilder.select(selectList).from(table) + .orderBy(createOrderByFields(tableToUse, sort)); + OptionalLong limit = OptionalLong.empty(); + OptionalLong offset = OptionalLong.empty(); - for (Order order : sort) { - - if (sortClause.length() != 0) { - sortClause.append(',').append(' '); - } - - sortClause.append(order.getProperty()).append(' ').append(order.getDirection().isAscending() ? "ASC" : "DESC"); + if (page.isPaged()) { + limit = OptionalLong.of(page.getPageSize()); + offset = OptionalLong.of(page.getOffset()); } - return sortClause; - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#selectById(java.lang.String, java.util.Set, java.lang.String) - */ - @Override - public BindIdOperation selectById(String table, Set columns, String idColumn) { - - return new DefaultBindIdOperation(dialect.getBindMarkersFactory().create(), marker -> { - String columnClause = StringUtils.collectionToDelimitedString(columns, ", "); - - return String.format("SELECT %s FROM %s WHERE %s = %s", columnClause, table, idColumn, marker.getPlaceholder()); - }); + return StatementRenderUtil.render(selectBuilder.build(), limit, offset, this.dialect); } - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#selectById(java.lang.String, java.util.Set, java.lang.String, int) - */ - @Override - public BindIdOperation selectById(String table, Set columns, String idColumn, int limit) { - - LimitClause limitClause = dialect.limit(); + private Collection createOrderByFields(Table table, Sort sortToUse) { - return new DefaultBindIdOperation(dialect.getBindMarkersFactory().create(), marker -> { + List fields = new ArrayList<>(); - String columnClause = StringUtils.collectionToDelimitedString(columns, ", "); + for (Order order : sortToUse) { - if (limitClause.getClausePosition() == Position.END) { + OrderByField orderByField = OrderByField.from(table.column(order.getProperty())); - return String.format("SELECT %s FROM %s WHERE %s = %s ORDER BY %s %s", columnClause, table, idColumn, - marker.getPlaceholder(), idColumn, limitClause.getClause(limit)); + if (order.getDirection() != null) { + fields.add(order.isAscending() ? orderByField.asc() : orderByField.desc()); + } else { + fields.add(orderByField); } + } - throw new UnsupportedOperationException( - String.format("Limit clause position %s not supported!", limitClause.getClausePosition())); - }); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#selectByIdIn(java.lang.String, java.util.Set, java.lang.String) - */ - @Override - public BindIdOperation selectByIdIn(String table, Set columns, String idColumn) { - - String query = String.format("SELECT %s FROM %s", StringUtils.collectionToDelimitedString(columns, ", "), table); - return new DefaultBindIdIn(dialect.getBindMarkersFactory().create(), query, idColumn); + return fields; } /* diff --git a/src/main/java/org/springframework/data/r2dbc/function/ReactiveDataAccessStrategy.java b/src/main/java/org/springframework/data/r2dbc/function/ReactiveDataAccessStrategy.java index f715e6e43..2d3b1b5b8 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/ReactiveDataAccessStrategy.java +++ b/src/main/java/org/springframework/data/r2dbc/function/ReactiveDataAccessStrategy.java @@ -108,42 +108,7 @@ public interface ReactiveDataAccessStrategy { * @param page * @return */ - QueryOperation select(String table, Set columns, Sort sort, Pageable page); - - /** - * Create a {@code SELECT … WHERE id = ?} operation for the given {@code table} using {@code columns} to project and - * {@code idColumn}. - * - * @param table the table to insert data to. - * @param columns columns to return. - * @param idColumn name of the primary key. - * @return - */ - BindIdOperation selectById(String table, Set columns, String idColumn); - - /** - * Create a {@code SELECT … WHERE id = ?} operation for the given {@code table} using {@code columns} to project and - * {@code idColumn} applying a limit (TOP, LIMIT, …). - * - * @param table the table to insert data to. - * @param columns columns to return. - * @param idColumn name of the primary key. - * @param limit number of rows to return. - * @return - */ - BindIdOperation selectById(String table, Set columns, String idColumn, int limit); - - /** - * Create a {@code SELECT … WHERE id IN (?)} operation for the given {@code table} using {@code columns} to project - * and {@code idColumn}. The actual {@link BindableOperation#toQuery() query} string depends on - * {@link BindIdOperation#bindIds(Statement, Iterable) bound parameters}. - * - * @param table the table to insert data to. - * @param columns columns to return. - * @param idColumn name of the primary key. - * @return - */ - BindIdOperation selectByIdIn(String table, Set columns, String idColumn); + String select(String table, Set columns, Sort sort, Pageable page); /** * Create a {@code UPDATE … SET … WHERE id = ?} operation for the given {@code table} updating {@code columns} and diff --git a/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java b/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java index 24558f89f..4a6ee07b6 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java @@ -21,13 +21,17 @@ import lombok.RequiredArgsConstructor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.Collections; +import java.util.ArrayList; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.BiConsumer; import org.reactivestreams.Publisher; + +import org.springframework.data.r2dbc.dialect.BindMarker; +import org.springframework.data.r2dbc.dialect.BindMarkers; import org.springframework.data.r2dbc.function.BindIdOperation; import org.springframework.data.r2dbc.function.BindableOperation; import org.springframework.data.r2dbc.function.DatabaseClient; @@ -35,6 +39,14 @@ import org.springframework.data.r2dbc.function.DatabaseClient.GenericExecuteSpec import org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy; import org.springframework.data.r2dbc.function.convert.MappingR2dbcConverter; import org.springframework.data.r2dbc.function.convert.SettableValue; +import org.springframework.data.relational.core.sql.Conditions; +import org.springframework.data.relational.core.sql.Expression; +import org.springframework.data.relational.core.sql.Functions; +import org.springframework.data.relational.core.sql.SQL; +import org.springframework.data.relational.core.sql.Select; +import org.springframework.data.relational.core.sql.StatementBuilder; +import org.springframework.data.relational.core.sql.Table; +import org.springframework.data.relational.core.sql.render.SqlRenderer; import org.springframework.data.relational.repository.query.RelationalEntityInformation; import org.springframework.data.repository.reactive.ReactiveCrudRepository; import org.springframework.util.Assert; @@ -118,13 +130,17 @@ public class SimpleR2dbcRepository implements ReactiveCrudRepository columns = new LinkedHashSet<>(accessStrategy.getAllColumns(entity.getJavaType())); String idColumnName = getIdColumnName(); - BindIdOperation select = accessStrategy.selectById(entity.getTableName(), columns, idColumnName); - GenericExecuteSpec sql = databaseClient.execute().sql(select); - BindSpecAdapter wrapper = BindSpecAdapter.create(sql); - select.bindId(wrapper, id); + BindMarkers bindMarkers = accessStrategy.getBindMarkersFactory().create(); + BindMarker bindMarker = bindMarkers.next("id"); - return wrapper.getBoundOperation().as(entity.getJavaType()) // + Table table = Table.create(entity.getTableName()); + Select select = StatementBuilder.select(table.columns(columns)).from(table) + .where(Conditions.isEqual(table.column(idColumnName), SQL.bindMarker(bindMarker.getPlaceholder()))).build(); + + return databaseClient.execute().sql(SqlRenderer.render(select)) // + .bind(0, id) // + .as(entity.getJavaType()) // .fetch() // .one(); } @@ -146,14 +162,16 @@ public class SimpleR2dbcRepository implements ReactiveCrudRepository wrapper = BindSpecAdapter.create(sql); - select.bindId(wrapper, id); + BindMarkers bindMarkers = accessStrategy.getBindMarkersFactory().create(); + BindMarker bindMarker = bindMarkers.next("id"); - return wrapper.getBoundOperation().as(entity.getJavaType()) // + Table table = Table.create(entity.getTableName()); + Select select = StatementBuilder.select(table.column(idColumnName)).from(table) + .where(Conditions.isEqual(table.column(idColumnName), SQL.bindMarker(bindMarker.getPlaceholder()))).build(); + + return databaseClient.execute().sql(SqlRenderer.render(select)) // + .bind(0, id) // .map((r, md) -> r) // .first() // .hasElement(); @@ -202,12 +220,26 @@ public class SimpleR2dbcRepository implements ReactiveCrudRepository columns = new LinkedHashSet<>(accessStrategy.getAllColumns(entity.getJavaType())); String idColumnName = getIdColumnName(); - BindIdOperation select = accessStrategy.selectByIdIn(entity.getTableName(), columns, idColumnName); - BindSpecAdapter wrapper = BindSpecAdapter.create(databaseClient.execute().sql(select)); - select.bindIds(wrapper, ids); + BindMarkers bindMarkers = accessStrategy.getBindMarkersFactory().create(); + + List markers = new ArrayList<>(); - return wrapper.getBoundOperation().as(entity.getJavaType()).fetch().all(); + for (int i = 0; i < ids.size(); i++) { + markers.add(SQL.bindMarker(bindMarkers.next("id").getPlaceholder())); + } + + Table table = Table.create(entity.getTableName()); + Select select = StatementBuilder.select(table.columns(columns)).from(table) + .where(Conditions.in(table.column(idColumnName), markers)).build(); + + GenericExecuteSpec executeSpec = databaseClient.execute().sql(SqlRenderer.render(select)); + + for (int i = 0; i < ids.size(); i++) { + executeSpec = executeSpec.bind(i, ids.get(i)); + } + + return executeSpec.as(entity.getJavaType()).fetch().all(); }); } @@ -217,8 +249,10 @@ public class SimpleR2dbcRepository implements ReactiveCrudRepository count() { - return databaseClient.execute() - .sql(String.format("SELECT COUNT(%s) FROM %s", getIdColumnName(), entity.getTableName())) // + Table table = Table.create(entity.getTableName()); + Select select = StatementBuilder.select(Functions.count(table.column(getIdColumnName()))).from(table).build(); + + return databaseClient.execute().sql(SqlRenderer.render(select)) // .map((r, md) -> r.get(0, Long.class)) // .first() // .defaultIfEmpty(0L); diff --git a/src/main/java/org/springframework/data/r2dbc/support/StatementRenderUtil.java b/src/main/java/org/springframework/data/r2dbc/support/StatementRenderUtil.java new file mode 100644 index 000000000..184cf6784 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/support/StatementRenderUtil.java @@ -0,0 +1,65 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.support; + +import java.util.OptionalLong; + +import org.springframework.data.r2dbc.dialect.Dialect; +import org.springframework.data.r2dbc.dialect.LimitClause; +import org.springframework.data.relational.core.sql.Select; +import org.springframework.data.relational.core.sql.render.SqlRenderer; + +/** + * Utility class to assist with SQL rendering. Mainly for internal use within the framework. + * + * @author Mark Paluch + */ +public abstract class StatementRenderUtil { + + /** + * Render {@link Select} to SQL considering {@link Dialect} specifics. + * + * @param select must not be {@literal null}. + * @param limit must not be {@literal null}. + * @param offset must not be {@literal null}. + * @param dialect must not be {@literal null}. + * @return the rendered SQL statement. + */ + public static String render(Select select, OptionalLong limit, OptionalLong offset, Dialect dialect) { + + String sql = SqlRenderer.render(select); + + // TODO: Replace with proper {@link Dialect} rendering for limit/offset. + if (limit.isPresent()) { + + LimitClause limitClause = dialect.limit(); + + String clause; + if (offset.isPresent()) { + clause = limitClause.getClause(limit.getAsLong(), offset.getAsLong()); + } else { + clause = limitClause.getClause(limit.getAsLong()); + } + + return sql + " " + clause; + } + + return sql; + } + + private StatementRenderUtil() {} + +} diff --git a/src/test/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategyUnitTests.java b/src/test/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategyUnitTests.java index 01e1b1622..2f1e51ec6 100644 --- a/src/test/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategyUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategyUnitTests.java @@ -42,48 +42,6 @@ public class DefaultReactiveDataAccessStrategyUnitTests { assertThat(operation.toQuery()).isEqualTo("UPDATE table SET firstname = $2, lastname = $3 WHERE id = $1"); } - @Test // gh-20 - public void shouldRenderSelectByIdQuery() { - - BindableOperation operation = strategy.selectById("table", new HashSet<>(Arrays.asList("firstname", "lastname")), - "id"); - - assertThat(operation.toQuery()).isEqualTo("SELECT firstname, lastname FROM table WHERE id = $1"); - } - - @Test // gh-20 - public void shouldRenderSelectByIdQueryWithLimit() { - - BindableOperation operation = strategy.selectById("table", new HashSet<>(Arrays.asList("firstname", "lastname")), - "id", 10); - - assertThat(operation.toQuery()) - .isEqualTo("SELECT firstname, lastname FROM table WHERE id = $1 ORDER BY id LIMIT 10"); - } - - @Test // gh-20 - public void shouldFailRenderingSelectByIdInQueryWithoutBindings() { - - BindableOperation operation = strategy.selectByIdIn("table", new HashSet<>(Arrays.asList("firstname", "lastname")), - "id"); - - assertThatThrownBy(operation::toQuery).isInstanceOf(UnsupportedOperationException.class); - } - - @Test // gh-20 - public void shouldRenderSelectByIdInQuery() { - - Statement statement = mock(Statement.class); - BindIdOperation operation = strategy.selectByIdIn("table", new HashSet<>(Arrays.asList("firstname", "lastname")), - "id"); - - operation.bindId(statement, Collections.singleton("foo")); - assertThat(operation.toQuery()).isEqualTo("SELECT firstname, lastname FROM table WHERE id IN ($1)"); - - operation.bindId(statement, "bar"); - assertThat(operation.toQuery()).isEqualTo("SELECT firstname, lastname FROM table WHERE id IN ($1, $2)"); - } - @Test // gh-20 public void shouldRenderDeleteByIdQuery() {