diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ConnectionFunction.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ConnectionFunction.java index 4d1b2652947..12a9f8fc80a 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ConnectionFunction.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ConnectionFunction.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,34 +20,18 @@ import java.util.function.Function; import io.r2dbc.spi.Connection; + /** * Union type combining {@link Function} and {@link SqlProvider} to expose the SQL that is - * related to the underlying action. + * related to the underlying action. The SqlProvider can support lazy / generate once semantics, + * in which case {@link #getSql()} can be {@code null} until the {@code #apply(Connection)} + * method is invoked. * * @author Mark Paluch + * @author Simon Baslé * @since 5.3 * @param the type of the result of the function. */ -class ConnectionFunction implements Function, SqlProvider { - - private final String sql; - - private final Function function; - - - ConnectionFunction(String sql, Function function) { - this.sql = sql; - this.function = function; - } - - - @Override - public R apply(Connection t) { - return this.function.apply(t); - } - - @Override - public String getSql() { - return this.sql; - } +interface ConnectionFunction extends Function, SqlProvider { } + diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java index 13f1e00aca3..8ef54b37159 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -79,7 +79,10 @@ public interface DatabaseClient extends ConnectionAccessor { * the execution. The SQL string can contain either native parameter * bind markers or named parameters (e.g. {@literal :foo, :bar}) when * {@link NamedParameterExpander} is enabled. - *

Accepts {@link PreparedOperation} as SQL and binding {@link Supplier} + *

Accepts {@link PreparedOperation} as SQL and binding {@link Supplier}. + *

{code DatabaseClient} implementations should defer the resolution of + * the SQL string as much as possible, ideally up to the point where a + * {@code Subscription} happens. This is the case for the default implementation. * @param sqlSupplier a supplier for the SQL statement * @return a new {@link GenericExecuteSpec} * @see NamedParameterExpander diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java index 91b21d74f22..d7933fae1b6 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -60,6 +60,7 @@ import org.springframework.util.StringUtils; * @author Mark Paluch * @author Mingyuan Wu * @author Bogdan Ilchyshyn + * @author Simon Baslé * @since 5.3 */ class DefaultDatabaseClient implements DatabaseClient { @@ -322,9 +323,8 @@ class DefaultDatabaseClient implements DatabaseClient { return fetch().rowsUpdated().then(); } - private FetchSpec execute(Supplier sqlSupplier, BiFunction mappingFunction) { - String sql = getRequiredSql(sqlSupplier); - Function statementFunction = connection -> { + private ResultFunction getResultFunction(Supplier sqlSupplier) { + BiFunction statementFunction = (connection, sql) -> { if (logger.isDebugEnabled()) { logger.debug("Executing SQL statement [" + sql + "]"); } @@ -370,16 +370,16 @@ class DefaultDatabaseClient implements DatabaseClient { return statement; }; - Function> resultFunction = connection -> { - Statement statement = statementFunction.apply(connection); - return Flux.from(this.filterFunction.filter(statement, DefaultDatabaseClient.this.executeFunction)) - .cast(Result.class).checkpoint("SQL \"" + sql + "\" [DatabaseClient]"); - }; + return new ResultFunction(sqlSupplier, statementFunction, this.filterFunction, DefaultDatabaseClient.this.executeFunction); + } + + private FetchSpec execute(Supplier sqlSupplier, BiFunction mappingFunction) { + ResultFunction resultHandler = getResultFunction(sqlSupplier); return new DefaultFetchSpec<>( - DefaultDatabaseClient.this, sql, - new ConnectionFunction<>(sql, resultFunction), - new ConnectionFunction<>(sql, connection -> sumRowsUpdated(resultFunction, connection)), + DefaultDatabaseClient.this, + resultHandler, + connection -> sumRowsUpdated(resultHandler, connection), mappingFunction); } diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultFetchSpec.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultFetchSpec.java index 85d8bd311a1..b92f96c2462 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultFetchSpec.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultFetchSpec.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ import java.util.function.BiFunction; import java.util.function.Function; import io.r2dbc.spi.Connection; -import io.r2dbc.spi.Result; import io.r2dbc.spi.Row; import io.r2dbc.spi.RowMetadata; import reactor.core.publisher.Flux; @@ -32,6 +31,7 @@ import org.springframework.dao.IncorrectResultSizeDataAccessException; * Default {@link FetchSpec} implementation. * * @author Mark Paluch + * @author Simon Baslé * @since 5.3 * @param the row result type */ @@ -39,24 +39,21 @@ class DefaultFetchSpec implements FetchSpec { private final ConnectionAccessor connectionAccessor; - private final String sql; - - private final Function> resultFunction; + private final ResultFunction resultFunction; private final Function> updatedRowsFunction; private final BiFunction mappingFunction; - DefaultFetchSpec(ConnectionAccessor connectionAccessor, String sql, - Function> resultFunction, + DefaultFetchSpec(ConnectionAccessor connectionAccessor, + ResultFunction resultFunction, Function> updatedRowsFunction, BiFunction mappingFunction) { - this.sql = sql; this.connectionAccessor = connectionAccessor; this.resultFunction = resultFunction; - this.updatedRowsFunction = updatedRowsFunction; + this.updatedRowsFunction = new DelegateConnectionFunction<>(resultFunction, updatedRowsFunction); this.mappingFunction = mappingFunction; } @@ -70,7 +67,7 @@ class DefaultFetchSpec implements FetchSpec { } if (list.size() > 1) { return Mono.error(new IncorrectResultSizeDataAccessException( - String.format("Query [%s] returned non unique result.", this.sql), + String.format("Query [%s] returned non unique result.", this.resultFunction.getSql()), 1)); } return Mono.just(list.get(0)); @@ -84,7 +81,7 @@ class DefaultFetchSpec implements FetchSpec { @Override public Flux all() { - return this.connectionAccessor.inConnectionMany(new ConnectionFunction<>(this.sql, + return this.connectionAccessor.inConnectionMany(new DelegateConnectionFunction<>(this.resultFunction, connection -> this.resultFunction.apply(connection) .flatMap(result -> result.map(this.mappingFunction)))); } diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DelegateConnectionFunction.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DelegateConnectionFunction.java new file mode 100644 index 00000000000..a25bddd7f66 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DelegateConnectionFunction.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.function.Function; + +import io.r2dbc.spi.Connection; + +import org.springframework.lang.Nullable; + +/** + * A {@link ConnectionFunction} that delegates to a {@code SqlProvider} and a plain + * {@code Function}. + * + * @author Simon Baslé + * @since 5.3.26 + * @param the type of the result of the function. + */ +final class DelegateConnectionFunction implements ConnectionFunction { + + private final SqlProvider sql; + + private final Function function; + + + DelegateConnectionFunction(SqlProvider sql, Function function) { + this.sql = sql; + this.function = function; + } + + + @Override + public R apply(Connection t) { + return this.function.apply(t); + } + + @Nullable + @Override + public String getSql() { + return this.sql.getSql(); + } +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ResultFunction.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ResultFunction.java new file mode 100644 index 00000000000..1204eac5df0 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ResultFunction.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Statement; +import reactor.core.publisher.Flux; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * A {@link ConnectionFunction} that produces a {@code Flux} of {@link Result} and that + * defers generation of the SQL until the function has been applied. + * Beforehand, the {@code getSql()} method simply returns {@code null}. The sql String is + * also memoized during application, so that subsequent calls to {@link #getSql()} return + * the same {@code String} without further calls to the {@code Supplier}. + * + * @author Mark Paluch + * @author Simon Baslé + * @since 5.3.26 + */ +final class ResultFunction implements ConnectionFunction> { + + final Supplier sqlSupplier; + final BiFunction statementFunction; + final StatementFilterFunction filterFunction; + final ExecuteFunction executeFunction; + + @Nullable + String resolvedSql = null; + + ResultFunction(Supplier sqlSupplier, BiFunction statementFunction, StatementFilterFunction filterFunction, ExecuteFunction executeFunction) { + this.sqlSupplier = sqlSupplier; + this.statementFunction = statementFunction; + this.filterFunction = filterFunction; + this.executeFunction = executeFunction; + } + + @Override + public Flux apply(Connection connection) { + String sql = this.sqlSupplier.get(); + Assert.state(StringUtils.hasText(sql), "SQL returned by supplier must not be empty"); + this.resolvedSql = sql; + Statement statement = this.statementFunction.apply(connection, sql); + return Flux.from(this.filterFunction.filter(statement, this.executeFunction)) + .cast(Result.class).checkpoint("SQL \"" + sql + "\" [DatabaseClient]"); + } + + @Nullable + @Override + public String getSql() { + return this.resolvedSql; + } +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java index 6f03e82bd42..302191b65a3 100644 --- a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java @@ -17,6 +17,8 @@ package org.springframework.r2dbc.core; import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import io.r2dbc.spi.Connection; import io.r2dbc.spi.ConnectionFactory; @@ -64,6 +66,7 @@ import static org.mockito.BDDMockito.when; * @author Mark Paluch * @author Ferdinand Jacobs * @author Jens Schauder + * @author Simon Baslé */ @ExtendWith(MockitoExtension.class) @MockitoSettings(strictness = Strictness.LENIENT) @@ -397,6 +400,47 @@ class DefaultDatabaseClientUnitTests { inOrder.verifyNoMoreInteractions(); } + @Test + void sqlSupplierInvocationIsDeferredUntilSubscription() { + // We'll have either 2 or 3 rows, depending on the subscription and the generated SQL + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata( + MockColumnMetadata.builder().name("id").javaType(Integer.class).build()).build(); + final MockRow row1 = MockRow.builder().identified("id", Integer.class, 1).build(); + final MockRow row2 = MockRow.builder().identified("id", Integer.class, 2).build(); + final MockRow row3 = MockRow.builder().identified("id", Integer.class, 3).build(); + // Set up 2 mock statements + mockStatementFor("SELECT id FROM test WHERE id < '3'", MockResult.builder() + .rowMetadata(metadata) + .row(row1, row2).build()); + mockStatementFor("SELECT id FROM test WHERE id < '4'", MockResult.builder() + .rowMetadata(metadata) + .row(row1, row2, row3).build()); + // Create the client + DatabaseClient databaseClient = this.databaseClientBuilder.build(); + + AtomicInteger invoked = new AtomicInteger(); + // Assemble a publisher, but don't subscribe yet + Mono> operation = databaseClient + .sql(() -> { + int idMax = 2 + invoked.incrementAndGet(); + return String.format("SELECT id FROM test WHERE id < '%s'", idMax); + }) + .map(r -> r.get("id", Integer.class)) + .all() + .collectList(); + + assertThat(invoked).as("invoked (before subscription)").hasValue(0); + + List rows = operation.block(); + assertThat(invoked).as("invoked (after 1st subscription)").hasValue(1); + assertThat(rows).containsExactly(1, 2); + + rows = operation.block(); + assertThat(invoked).as("invoked (after 2nd subscription)").hasValue(2); + assertThat(rows).containsExactly(1, 2, 3); + } + + private Statement mockStatement() { return mockStatementFor(null, null); }