diff --git a/src/main/java/org/springframework/data/jdbc/core/function/DefaultDatabaseClient.java b/src/main/java/org/springframework/data/jdbc/core/function/DefaultDatabaseClient.java index de640cac2..d47867635 100644 --- a/src/main/java/org/springframework/data/jdbc/core/function/DefaultDatabaseClient.java +++ b/src/main/java/org/springframework/data/jdbc/core/function/DefaultDatabaseClient.java @@ -102,25 +102,75 @@ class DefaultDatabaseClient implements DatabaseClient { return new DefaultInsertIntoSpec(); } - public Flux execute(Function> action) throws DataAccessException { + /** + * Execute a callback {@link Function} within a {@link Connection} scope. The function is responsible for creating a + * {@link Flux}. The connection is released after the {@link Flux} terminates (or the subscription is cancelled). + * Connection resources must not be passed outside of the {@link Function} closure, otherwise resources may get + * defunct. + * + * @param action must not be {@literal null}. + * @return the resulting {@link Flux}. + * @throws DataAccessException + */ + public Flux executeMany(Function> action) throws DataAccessException { Assert.notNull(action, "Callback object must not be null"); - Mono connectionMono = Mono.from(obtainConnectionFactory().create()); + Mono connectionMono = getConnection(); // Create close-suppressing Connection proxy, also preparing returned Statements. - return connectionMono.flatMapMany(connection -> { + return Flux.usingWhen(connectionMono, it -> { - Connection connectionToUse = createConnectionProxy(connection); + Connection connectionToUse = createConnectionProxy(it); - // TODO: Release connection - return doInConnection(action, connectionToUse); + return doInConnectionMany(connectionToUse, action); + }, this::closeConnection, this::closeConnection, this::closeConnection) // + .onErrorMap(SQLException.class, ex -> translateException("executeMany", getSql(action), ex)); + } - }).onErrorMap(SQLException.class, ex -> { + /** + * Execute a callback {@link Function} within a {@link Connection} scope. The function is responsible for creating a + * {@link Mono}. The connection is released after the {@link Mono} terminates (or the subscription is cancelled). + * Connection resources must not be passed outside of the {@link Function} closure, otherwise resources may get + * defunct. + * + * @param action must not be {@literal null}. + * @return the resulting {@link Mono}. + * @throws DataAccessException + */ + public Mono execute(Function> action) throws DataAccessException { - String sql = getSql(action); - return translateException("ConnectionCallback", sql, ex); - }); + Assert.notNull(action, "Callback object must not be null"); + + Mono connectionMono = getConnection(); + // Create close-suppressing Connection proxy, also preparing returned Statements. + + return Mono.usingWhen(connectionMono, it -> { + + Connection connectionToUse = createConnectionProxy(it); + + return doInConnection(connectionToUse, action); + }, this::closeConnection, this::closeConnection, this::closeConnection) // + .onErrorMap(SQLException.class, ex -> translateException("execute", getSql(action), ex)); + } + + /** + * Obtain a {@link Connection}. + * + * @return + */ + protected Mono getConnection() { + return Mono.from(obtainConnectionFactory().create()); + } + + /** + * Release the {@link Connection}. + * + * @param connection + * @return + */ + protected Publisher closeConnection(Connection connection) { + return connection.close(); } /** @@ -153,11 +203,12 @@ class DefaultDatabaseClient implements DatabaseClient { * @return a DataAccessException wrapping the {@code SQLException} (never {@code null}) */ protected DataAccessException translateException(String task, @Nullable String sql, SQLException ex) { + DataAccessException dae = exceptionTranslator.translate(task, sql, ex); return (dae != null ? dae : new UncategorizedSQLException(task, sql, ex)); } - static void doBind(Statement statement, Map> byName, + private static void doBind(Statement statement, Map> byName, Map> byIndex) { byIndex.forEach((i, o) -> { @@ -225,16 +276,26 @@ class DefaultDatabaseClient implements DatabaseClient { return sql; } - Mono> exchange(String sql, BiFunction mappingFunction) { + protected SqlResult exchange(String sql, BiFunction mappingFunction) { - return execute(it -> { + Function executeFunction = it -> { + + if (logger.isDebugEnabled()) { + logger.debug("Executing SQL statement [" + sql + "]"); + } Statement statement = it.createStatement(sql); doBind(statement, byName, byIndex); - return Flux - .just((SqlResult) new DefaultSqlResult<>(sql, Flux.from(statement.add().execute()), mappingFunction)); - }).next(); + return statement; + }; + + Function> resultFunction = it -> Flux.from(executeFunction.apply(it).execute()); + + return new DefaultSqlResultFunctions<>(sql, // + resultFunction, // + it -> resultFunction.apply(it).flatMap(Result::getRowsUpdated).next(), // + mappingFunction); } public GenericExecuteSpecSupport bind(int index, Object value) { @@ -310,15 +371,12 @@ class DefaultDatabaseClient implements DatabaseClient { @Override public FetchSpec> fetch() { - - String sql = getSql(); - return new DefaultFetchSpec<>(sql, exchange(sql, ColumnMapRowMapper.INSTANCE).flatMapMany(SqlResult::all), - exchange(sql, ColumnMapRowMapper.INSTANCE).flatMap(FetchSpec::rowsUpdated)); + return exchange(getSql(), ColumnMapRowMapper.INSTANCE); } @Override public Mono>> exchange() { - return exchange(getSql(), ColumnMapRowMapper.INSTANCE); + return Mono.just(exchange(getSql(), ColumnMapRowMapper.INSTANCE)); } @Override @@ -381,14 +439,12 @@ class DefaultDatabaseClient implements DatabaseClient { @Override public FetchSpec fetch() { - String sql = getSql(); - return new DefaultFetchSpec<>(sql, exchange(sql, mappingFunction).flatMapMany(SqlResult::all), - exchange(sql, mappingFunction).flatMap(FetchSpec::rowsUpdated)); + return exchange(getSql(), mappingFunction); } @Override public Mono> exchange() { - return exchange(getSql(), mappingFunction); + return Mono.just(exchange(getSql(), mappingFunction)); } @Override @@ -472,11 +528,15 @@ class DefaultDatabaseClient implements DatabaseClient { @Override public Mono then() { - return exchange().flatMapMany(FetchSpec::all).then(); + return exchange((row, md) -> row).all().then(); } @Override public Mono>> exchange() { + return Mono.just(exchange(ColumnMapRowMapper.INSTANCE)); + } + + private SqlResult exchange(BiFunction mappingFunction) { if (byName.isEmpty()) { throw new IllegalStateException("Insert fields is empty!"); @@ -490,26 +550,43 @@ class DefaultDatabaseClient implements DatabaseClient { builder.append("INSERT INTO ").append(table).append(" (").append(fieldNames).append(") ").append(" VALUES(") .append(placeholders).append(")"); - return execute(it -> { + String sql = builder.toString(); + Function insertFunction = it -> { - String sql = builder.toString(); + if (logger.isDebugEnabled()) { + logger.debug("Executing SQL statement [" + sql + "]"); + } Statement statement = it.createStatement(sql); + doBind(statement); + return statement; + }; - AtomicInteger index = new AtomicInteger(); - for (Optional o : byName.values()) { + Function> resultFunction = it -> Flux + .from(insertFunction.apply(it).executeReturningGeneratedKeys()); - if (o.isPresent()) { - o.ifPresent(v -> statement.bind(index.getAndIncrement(), v)); - } else { - statement.bindNull("$" + (index.getAndIncrement() + 1), 0); // TODO: What is type? - } - } + return new DefaultSqlResultFunctions<>(sql, // + resultFunction, // + it -> resultFunction.apply(it).flatMap(Result::getRowsUpdated).next(), // + mappingFunction); + } + + /** + * PostgreSQL-specific bind. + * + * @param statement + */ + private void doBind(Statement statement) { - SqlResult> result = new DefaultSqlResult<>(sql, - Flux.from(statement.executeReturningGeneratedKeys()), ColumnMapRowMapper.INSTANCE); - return Flux.just(result); + AtomicInteger index = new AtomicInteger(); - }).next(); + for (Optional o : byName.values()) { + + if (o.isPresent()) { + o.ifPresent(v -> statement.bind(index.getAndIncrement(), v)); + } else { + statement.bindNull("$" + (index.getAndIncrement() + 1), 0); // TODO: What is type? + } + } } } @@ -523,7 +600,7 @@ class DefaultDatabaseClient implements DatabaseClient { private final String table; private final Publisher objectToInsert; - public DefaultTypedInsertSpec(Class typeToInsert) { + DefaultTypedInsertSpec(Class typeToInsert) { this.typeToInsert = typeToInsert; this.table = dataAccessStrategy.getTableName(typeToInsert); @@ -556,69 +633,84 @@ class DefaultDatabaseClient implements DatabaseClient { @Override public Mono then() { - return exchange().flatMapMany(FetchSpec::all).then(); + return Mono.from(objectToInsert).map(toInsert -> exchange(toInsert, (row, md) -> row).all()).then(); } @Override public Mono>> exchange() { + return Mono.from(objectToInsert).map(toInsert -> exchange(toInsert, ColumnMapRowMapper.INSTANCE)); + } - return Mono.from(objectToInsert).flatMap(toInsert -> { + private SqlResult exchange(Object toInsert, BiFunction mappingFunction) { - StringBuilder builder = new StringBuilder(); + StringBuilder builder = new StringBuilder(); - List> insertValues = dataAccessStrategy.getInsert(toInsert); - String fieldNames = insertValues.stream().map(Pair::getFirst).collect(Collectors.joining(",")); - String placeholders = IntStream.range(0, insertValues.size()).mapToObj(i -> "$" + (i + 1)) - .collect(Collectors.joining(",")); + List> insertValues = dataAccessStrategy.getInsert(toInsert); + String fieldNames = insertValues.stream().map(Pair::getFirst).collect(Collectors.joining(",")); + String placeholders = IntStream.range(0, insertValues.size()).mapToObj(i -> "$" + (i + 1)) + .collect(Collectors.joining(",")); - builder.append("INSERT INTO ").append(table).append(" (").append(fieldNames).append(") ").append(" VALUES(") - .append(placeholders).append(")"); + builder.append("INSERT INTO ").append(table).append(" (").append(fieldNames).append(") ").append(" VALUES(") + .append(placeholders).append(")"); - return execute(it -> { + String sql = builder.toString(); - String sql = builder.toString(); - Statement statement = it.createStatement(sql); + Function insertFunction = it -> { - AtomicInteger index = new AtomicInteger(); + if (logger.isDebugEnabled()) { + logger.debug("Executing SQL statement [" + sql + "]"); + } - for (Pair pair : insertValues) { + Statement statement = it.createStatement(sql); - if (pair.getSecond() != null) { // TODO: Better type to transport null values. - statement.bind(index.getAndIncrement(), pair.getSecond()); - } else { - statement.bindNull("$" + (index.getAndIncrement() + 1), 0); // TODO: What is type? - } + AtomicInteger index = new AtomicInteger(); + + for (Pair pair : insertValues) { + + if (pair.getSecond() != null) { // TODO: Better type to transport null values. + statement.bind(index.getAndIncrement(), pair.getSecond()); + } else { + statement.bindNull("$" + (index.getAndIncrement() + 1), 0); // TODO: What is type? } + } + + return statement; + }; - SqlResult> result = new DefaultSqlResult<>(sql, - Flux.from(statement.executeReturningGeneratedKeys()), ColumnMapRowMapper.INSTANCE); - return Flux.just(result); + Function> resultFunction = it -> Flux + .from(insertFunction.apply(it).executeReturningGeneratedKeys()); - }).next(); - }); + return new DefaultSqlResultFunctions<>(sql, // + resultFunction, // + it -> resultFunction.apply(it).flatMap(Result::getRowsUpdated).next(), // + mappingFunction); } } /** - * Default {@link org.springframework.data.jdbc.core.function.DatabaseClient.SqlResult} implementation. + * Default {@link org.springframework.data.jdbc.core.function.SqlResult} implementation. */ - static class DefaultSqlResult implements SqlResult { + class DefaultSqlResultFunctions implements SqlResult { private final String sql; - private final Flux result; + private final Function> resultFunction; + private final Function> updatedRowsFunction; private final FetchSpec fetchSpec; - DefaultSqlResult(String sql, Flux result, BiFunction mappingFunction) { + DefaultSqlResultFunctions(String sql, Function> resultFunction, + Function> updatedRowsFunction, BiFunction mappingFunction) { this.sql = sql; - this.result = result; - this.fetchSpec = new DefaultFetchSpec<>(sql, result.flatMap(it -> it.map(mappingFunction)), - result.flatMap(Result::getRowsUpdated).next()); + this.resultFunction = resultFunction; + this.updatedRowsFunction = updatedRowsFunction; + + this.fetchSpec = new DefaultFetchFunctions<>(sql, + it -> resultFunction.apply(it).flatMap(result -> result.map(mappingFunction)), updatedRowsFunction); } @Override public SqlResult extract(BiFunction mappingFunction) { - return new DefaultSqlResult<>(sql, result, mappingFunction); + return new DefaultSqlResultFunctions<>(sql, resultFunction, updatedRowsFunction, mappingFunction); } @Override @@ -643,11 +735,11 @@ class DefaultDatabaseClient implements DatabaseClient { } @RequiredArgsConstructor - static class DefaultFetchSpec implements FetchSpec { + class DefaultFetchFunctions implements FetchSpec { private final String sql; - private final Flux result; - private final Mono updatedRows; + private final Function> resultFunction; + private final Function> updatedRowsFunction; @Override public Mono one() { @@ -675,19 +767,19 @@ class DefaultDatabaseClient implements DatabaseClient { @Override public Flux all() { - return result; + return executeMany(resultFunction); } @Override public Mono rowsUpdated() { - return updatedRows; + return execute(updatedRowsFunction); } } - private static Flux doInConnection(Function> action, Connection it) { + private static Flux doInConnectionMany(Connection connection, Function> action) { try { - return action.apply(it); + return action.apply(connection); } catch (RuntimeException e) { String sql = getSql(action); @@ -695,6 +787,17 @@ class DefaultDatabaseClient implements DatabaseClient { } } + private static Mono doInConnection(Connection connection, Function> action) { + + try { + return action.apply(connection); + } catch (RuntimeException e) { + + String sql = getSql(action); + return Mono.error(new DefaultDatabaseClient.UncategorizedSQLException("ConnectionCallback", sql, e) {}); + } + } + /** * Determine SQL from potential provider object. * @@ -764,7 +867,7 @@ class DefaultDatabaseClient implements DatabaseClient { } } - private static class UncategorizedSQLException extends UncategorizedDataAccessException { + private static class UncategorizedSQLException extends UncategorizedDataAccessException implements SqlProvider { /** SQL that led to the problem */ @Nullable private final String sql; diff --git a/src/test/java/org/springframework/data/jdbc/core/function/DatabaseClientIntegrationTests.java b/src/test/java/org/springframework/data/jdbc/core/function/DatabaseClientIntegrationTests.java index 53b4bdc65..ab8a49d70 100644 --- a/src/test/java/org/springframework/data/jdbc/core/function/DatabaseClientIntegrationTests.java +++ b/src/test/java/org/springframework/data/jdbc/core/function/DatabaseClientIntegrationTests.java @@ -121,12 +121,29 @@ public class DatabaseClientIntegrationTests { .value("name", "SCHAUFELRADBAGGER") // .nullValue("manual") // .exchange() // - .flatMapMany(it -> it.extract((r, m) -> r.get("id", Integer.class)).all()).as(StepVerifier::create) // + .flatMapMany(it -> it.extract((r, m) -> r.get("id", Integer.class)).all()) // + .as(StepVerifier::create) // .expectNext(42055).verifyComplete(); assertThat(jdbc.queryForMap("SELECT id, name, manual FROM legoset")).containsEntry("id", 42055); } + @Test + public void insertWithoutResult() { + + DatabaseClient databaseClient = DatabaseClient.create(connectionFactory); + + databaseClient.insert().into("legoset")// + .value("id", 42055) // + .value("name", "SCHAUFELRADBAGGER") // + .nullValue("manual") // + .then() // + .as(StepVerifier::create) // + .verifyComplete(); + + assertThat(jdbc.queryForMap("SELECT id, name, manual FROM legoset")).containsEntry("id", 42055); + } + @Test public void insertTypedObject() {