diff --git a/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClientBuilder.java b/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClientBuilder.java index 8ae87bf7f..57c64fcd8 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClientBuilder.java +++ b/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClientBuilder.java @@ -33,9 +33,9 @@ import org.springframework.util.Assert; */ class DefaultDatabaseClientBuilder implements DatabaseClient.Builder { - private @Nullable ConnectionFactory connector; - private @Nullable R2dbcExceptionTranslator exceptionTranslator; - private ReactiveDataAccessStrategy accessStrategy = new DefaultReactiveDataAccessStrategy(); + @Nullable ConnectionFactory connector; + @Nullable R2dbcExceptionTranslator exceptionTranslator; + ReactiveDataAccessStrategy accessStrategy = new DefaultReactiveDataAccessStrategy(); DefaultDatabaseClientBuilder() {} @@ -44,7 +44,7 @@ class DefaultDatabaseClientBuilder implements DatabaseClient.Builder { Assert.notNull(other, "DefaultDatabaseClientBuilder must not be null!"); this.connector = other.connector; - this.exceptionTranslator = exceptionTranslator; + this.exceptionTranslator = other.exceptionTranslator; } @Override @@ -83,8 +83,12 @@ class DefaultDatabaseClientBuilder implements DatabaseClient.Builder { exceptionTranslator = new SqlErrorCodeR2dbcExceptionTranslator(connector); } - return new DefaultDatabaseClient(this.connector, exceptionTranslator, accessStrategy, - new DefaultDatabaseClientBuilder(this)); + return doBuild(this.connector, exceptionTranslator, this.accessStrategy, new DefaultDatabaseClientBuilder(this)); + } + + protected DatabaseClient doBuild(ConnectionFactory connector, R2dbcExceptionTranslator exceptionTranslator, + ReactiveDataAccessStrategy accessStrategy, DefaultDatabaseClientBuilder builder) { + return new DefaultDatabaseClient(connector, exceptionTranslator, accessStrategy, builder); } @Override diff --git a/src/main/java/org/springframework/data/r2dbc/function/DefaultTransactionalDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/function/DefaultTransactionalDatabaseClient.java new file mode 100644 index 000000000..9206c244f --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/DefaultTransactionalDatabaseClient.java @@ -0,0 +1,167 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; +import reactor.util.function.Tuple2; + +import java.util.function.Function; + +import org.reactivestreams.Publisher; +import org.springframework.data.r2dbc.function.connectionfactory.ConnectionFactoryUtils; +import org.springframework.data.r2dbc.function.connectionfactory.ReactiveTransactionSynchronization; +import org.springframework.data.r2dbc.function.connectionfactory.TransactionResources; +import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator; +import org.springframework.transaction.NoTransactionException; + +/** + * Default implementation of a {@link TransactionalDatabaseClient}. + * + * @author Mark Paluch + */ +class DefaultTransactionalDatabaseClient extends DefaultDatabaseClient implements TransactionalDatabaseClient { + + DefaultTransactionalDatabaseClient(ConnectionFactory connector, R2dbcExceptionTranslator exceptionTranslator, + ReactiveDataAccessStrategy dataAccessStrategy, DefaultDatabaseClientBuilder builder) { + super(connector, exceptionTranslator, dataAccessStrategy, builder); + } + + @Override + public TransactionalDatabaseClient.Builder mutate() { + return (TransactionalDatabaseClient.Builder) super.mutate(); + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.TransactionalDatabaseClient#beginTransaction() + */ + @Override + public Mono beginTransaction() { + + Mono transactional = ConnectionFactoryUtils.currentReactiveTransactionSynchronization() // + .map(synchronization -> { + + TransactionResources transactionResources = TransactionResources.create(); + // TODO: This Tx management code creating a TransactionContext. Find a better place. + synchronization.registerTransaction(transactionResources); + return transactionResources; + }); + + return transactional.flatMap(it -> { + return ConnectionFactoryUtils.doGetConnection(obtainConnectionFactory()); + }).flatMap(it -> Mono.from(it.getT1().beginTransaction())); + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.TransactionalDatabaseClient#commitTransaction() + */ + @Override + public Mono commitTransaction() { + return cleanup(Connection::commitTransaction); + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.TransactionalDatabaseClient#rollbackTransaction() + */ + @Override + public Mono rollbackTransaction() { + return cleanup(Connection::rollbackTransaction); + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.TransactionalDatabaseClient#inTransaction(java.util.function.Function) + */ + @Override + public Flux inTransaction(Function> callback) { + + return Flux.usingWhen(beginTransaction().thenReturn(this), callback, // + DefaultTransactionalDatabaseClient::commitTransaction, // + DefaultTransactionalDatabaseClient::rollbackTransaction, // + DefaultTransactionalDatabaseClient::rollbackTransaction) // + .subscriberContext(DefaultTransactionalDatabaseClient::withTransactionSynchronization); + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.DefaultDatabaseClient#getConnection() + */ + @Override + protected Mono getConnection() { + return ConnectionFactoryUtils.getConnection(obtainConnectionFactory()).map(Tuple2::getT1); + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.DefaultDatabaseClient#closeConnection(io.r2dbc.spi.Connection) + */ + @Override + protected Publisher closeConnection(Connection connection) { + + return Mono.subscriberContext().flatMap(context -> { + + if (context.hasKey(ReactiveTransactionSynchronization.class)) { + + return ConnectionFactoryUtils.currentConnectionFactory() + .flatMap(it -> ConnectionFactoryUtils.releaseConnection(connection, it)); + } + + return Mono.from(connection.close()); + }); + } + + /** + * Execute a transactional cleanup. Also, deregister the current {@link TransactionResources synchronization} element. + */ + private static Mono cleanup(Function> callback) { + + return ConnectionFactoryUtils.currentActiveReactiveTransactionSynchronization() // + .flatMap(synchronization -> { + + TransactionResources currentSynchronization = synchronization.getCurrentTransaction(); + + ConnectionFactory connectionFactory = currentSynchronization.getResource(ConnectionFactory.class); + + if (connectionFactory == null) { + throw new NoTransactionException("No ConnectionFactory attached"); + } + + return Mono.from(connectionFactory.create()) + .flatMap(connection -> Mono.from(callback.apply(connection)) + .then(ConnectionFactoryUtils.releaseConnection(connection, connectionFactory)) + .then(ConnectionFactoryUtils.closeConnection(connection, connectionFactory))) // TODO: Is this rather + // related to + // TransactionContext + // cleanup? + .doFinally(s -> synchronization.unregisterTransaction(currentSynchronization)); + }); + } + + /** + * Potentially register a {@link ReactiveTransactionSynchronization} in the {@link Context} if no synchronization + * object is registered. + * + * @param context the subscriber context. + * @return subscriber context with a registered synchronization. + */ + static Context withTransactionSynchronization(Context context) { + + // associate synchronizer object to host transactional resources. + // TODO: Should be moved to a better place. + return context.put(ReactiveTransactionSynchronization.class, + context.getOrDefault(ReactiveTransactionSynchronization.class, new ReactiveTransactionSynchronization())); + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/DefaultTransactionalDatabaseClientBuilder.java b/src/main/java/org/springframework/data/r2dbc/function/DefaultTransactionalDatabaseClientBuilder.java new file mode 100644 index 000000000..7e322e059 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/DefaultTransactionalDatabaseClientBuilder.java @@ -0,0 +1,99 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function; + +import io.r2dbc.spi.ConnectionFactory; + +import java.util.function.Consumer; + +import org.springframework.data.r2dbc.function.DatabaseClient.Builder; +import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator; +import org.springframework.util.Assert; + +/** + * @author Mark Paluch + */ +class DefaultTransactionalDatabaseClientBuilder extends DefaultDatabaseClientBuilder + implements TransactionalDatabaseClient.Builder { + + DefaultTransactionalDatabaseClientBuilder() {} + + DefaultTransactionalDatabaseClientBuilder(DefaultDatabaseClientBuilder other) { + + Assert.notNull(other, "DefaultDatabaseClientBuilder must not be null!"); + + this.connector = other.connector; + this.exceptionTranslator = other.exceptionTranslator; + } + + @Override + public DatabaseClient.Builder clone() { + return new DefaultTransactionalDatabaseClientBuilder(this); + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.DefaultDatabaseClientBuilder#connectionFactory(io.r2dbc.spi.ConnectionFactory) + */ + @Override + public TransactionalDatabaseClient.Builder connectionFactory(ConnectionFactory factory) { + super.connectionFactory(factory); + return this; + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.DefaultDatabaseClientBuilder#exceptionTranslator(org.springframework.data.r2dbc.support.R2dbcExceptionTranslator) + */ + @Override + public TransactionalDatabaseClient.Builder exceptionTranslator(R2dbcExceptionTranslator exceptionTranslator) { + super.exceptionTranslator(exceptionTranslator); + return this; + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.DefaultDatabaseClientBuilder#dataAccessStrategy(org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy) + */ + @Override + public TransactionalDatabaseClient.Builder dataAccessStrategy(ReactiveDataAccessStrategy accessStrategy) { + super.dataAccessStrategy(accessStrategy); + return this; + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.DefaultDatabaseClientBuilder#apply(java.util.function.Consumer) + */ + @Override + public TransactionalDatabaseClient.Builder apply(Consumer builderConsumer) { + super.apply(builderConsumer); + return this; + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.DefaultDatabaseClientBuilder#build() + */ + @Override + public TransactionalDatabaseClient build() { + return (TransactionalDatabaseClient) super.build(); + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.DefaultDatabaseClientBuilder#doBuild(io.r2dbc.spi.ConnectionFactory, org.springframework.data.r2dbc.support.R2dbcExceptionTranslator, org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy, org.springframework.data.r2dbc.function.DefaultDatabaseClientBuilder) + */ + @Override + protected DatabaseClient doBuild(ConnectionFactory connector, R2dbcExceptionTranslator exceptionTranslator, + ReactiveDataAccessStrategy accessStrategy, DefaultDatabaseClientBuilder builder) { + return new DefaultTransactionalDatabaseClient(connector, exceptionTranslator, accessStrategy, builder); + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/TransactionalDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/function/TransactionalDatabaseClient.java new file mode 100644 index 000000000..a747032ca --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/TransactionalDatabaseClient.java @@ -0,0 +1,202 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function; + +import io.r2dbc.spi.ConnectionFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.function.Consumer; +import java.util.function.Function; + +import org.reactivestreams.Publisher; +import org.springframework.data.r2dbc.function.connectionfactory.TransactionResources; +import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator; +import org.springframework.util.Assert; + +/** + * {@link DatabaseClient} that participates in an ongoing transaction if the subscription happens within a hosted + * transaction. Alternatively, transactions can be started and cleaned up using {@link #beginTransaction()} and + * {@link #commitTransaction()}. + *

+ * Transactional resources are bound to {@link ReactiveTransactionSynchronization} through nested + * {@link TransactionContext} enabling nested (parallel) transactions. The simplemost approach to use transactions is by + * using {@link #inTransaction(Function)} which will start a transaction and commit it on successful termination. The + * callback allows execution of multiple statements within the same transaction. + * + *

+ * Flux transactionalFlux = databaseClient.inTransaction(db -> {
+ *
+ * 	return db.execute().sql("INSERT INTO person (id, firstname, lastname) VALUES($1, $2, $3)") //
+ * 			.bind(0, 1) //
+ * 			.bind(1, "Walter") //
+ * 			.bind(2, "White") //
+ * 			.fetch().rowsUpdated();
+ * });
+ * 
+ * + * Alternatively, transactions can be controlled by using {@link #beginTransaction()} and {@link #commitTransaction()} + * methods. This approach requires {@link #enableTransactionSynchronization(Publisher) enabling of transaction + * synchronization} for the transactional operation. + * + *
+ * Mono mono = databaseClient.beginTransaction()
+ * 		.then(databaseClient.execute().sql("INSERT INTO person (id, firstname, lastname) VALUES($1, $2, $3)") //
+ * 				.bind(0, 1) //
+ * 				.bind(1, "Walter") //
+ * 				.bind(2, "White") //
+ * 				.fetch().rowsUpdated())
+ * 		.then(databaseClient.commitTransaction());
+ *
+ * Mono transactionalMono = databaseClient.enableTransactionSynchronization(mono);
+ * 
+ *

+ * This {@link DatabaseClient} can be safely used without transaction synchronization to invoke database functionality + * in auto-commit transactions. + * + * @author Mark Paluch + * @see #inTransaction(Function) + * @see #enableTransactionSynchronization(Publisher) + * @see #beginTransaction() + * @see #commitTransaction() + * @see #rollbackTransaction() + * @see org.springframework.data.r2dbc.function.connectionfactory.ReactiveTransactionSynchronization + * @see TransactionResources + * @see org.springframework.data.r2dbc.function.connectionfactory.ConnectionFactoryUtils + */ +public interface TransactionalDatabaseClient extends DatabaseClient { + + /** + * Start a transaction and bind connection resources to the subscriber context. + * + * @return + */ + Mono beginTransaction(); + + /** + * Commit a transaction and unbind connection resources from the subscriber context. + * + * @return + * @throws org.springframework.transaction.NoTransactionException if no transaction is ongoing. + */ + Mono commitTransaction(); + + /** + * Rollback a transaction and unbind connection resources from the subscriber context. + * + * @return + * @throws org.springframework.transaction.NoTransactionException if no transaction is ongoing. + */ + Mono rollbackTransaction(); + + /** + * Execute a {@link Function} accepting a {@link DatabaseClient} within a managed transaction. {@link Exception Error + * signals} cause the transaction to be rolled back. + * + * @param callback + * @return the callback result. + */ + Flux inTransaction(Function> callback); + + /** + * Enable transaction management so that connections can be bound to the subscription. + * + * @param publisher must not be {@literal null}. + * @return the Transaction-enabled {@link Mono}. + */ + default Mono enableTransactionSynchronization(Mono publisher) { + + Assert.notNull(publisher, "Publisher must not be null!"); + + return publisher.subscriberContext(DefaultTransactionalDatabaseClient::withTransactionSynchronization); + } + + /** + * Enable transaction management so that connections can be bound to the subscription. + * + * @param publisher must not be {@literal null}. + * @return the Transaction-enabled {@link Flux}. + */ + default Flux enableTransactionSynchronization(Publisher publisher) { + + Assert.notNull(publisher, "Publisher must not be null!"); + + return Flux.from(publisher).subscriberContext(DefaultTransactionalDatabaseClient::withTransactionSynchronization); + } + + /** + * Return a builder to mutate properties of this database client. + */ + TransactionalDatabaseClient.Builder mutate(); + + // Static, factory methods + + /** + * A variant of {@link #create(ConnectionFactory)} that accepts a {@link io.r2dbc.spi.ConnectionFactory}. + */ + static TransactionalDatabaseClient create(ConnectionFactory factory) { + return (TransactionalDatabaseClient) new DefaultTransactionalDatabaseClientBuilder().connectionFactory(factory) + .build(); + } + + /** + * Obtain a {@code DatabaseClient} builder. + */ + static TransactionalDatabaseClient.Builder builder() { + return new DefaultTransactionalDatabaseClientBuilder(); + } + + /** + * A mutable builder for creating a {@link TransactionalDatabaseClient}. + */ + interface Builder extends DatabaseClient.Builder { + + /** + * Configures the {@link ConnectionFactory R2DBC connector}. + * + * @param factory must not be {@literal null}. + * @return {@code this} {@link DatabaseClient.Builder}. + */ + Builder connectionFactory(ConnectionFactory factory); + + /** + * Configures a {@link R2dbcExceptionTranslator}. + * + * @param exceptionTranslator must not be {@literal null}. + * @return {@code this} {@link DatabaseClient.Builder}. + */ + Builder exceptionTranslator(R2dbcExceptionTranslator exceptionTranslator); + + /** + * Configures a {@link ReactiveDataAccessStrategy}. + * + * @param accessStrategy must not be {@literal null}. + * @return {@code this} {@link DatabaseClient.Builder}. + */ + Builder dataAccessStrategy(ReactiveDataAccessStrategy accessStrategy); + + /** + * Configures a {@link Consumer} to configure this builder. + * + * @param builderConsumer must not be {@literal null}. + * @return {@code this} {@link DatabaseClient.Builder}. + */ + Builder apply(Consumer builderConsumer); + + @Override + TransactionalDatabaseClient build(); + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/ConnectionFactoryUtils.java b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/ConnectionFactoryUtils.java new file mode 100644 index 000000000..d2e60f962 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/ConnectionFactoryUtils.java @@ -0,0 +1,249 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function.connectionfactory; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.dao.DataAccessResourceFailureException; +import org.springframework.lang.Nullable; +import org.springframework.transaction.NoTransactionException; +import org.springframework.util.Assert; + +/** + * Helper class that provides static methods for obtaining R2DBC Connections from a + * {@link io.r2dbc.spi.ConnectionFactory}. + *

+ * Used internally by Spring's {@link org.springframework.data.r2dbc.function.DatabaseClient}, Spring's R2DBC operation + * objects. Can also be used directly in application code. + * + * @author Mark Paluch + */ +public class ConnectionFactoryUtils { + + private static final Log logger = LogFactory.getLog(ConnectionFactoryUtils.class); + + /** + * Obtain a {@link io.r2dbc.spi.Connection} from the given {@link io.r2dbc.spi.ConnectionFactory}. Translates + * exceptions into the Spring hierarchy of unchecked generic data access exceptions, simplifying calling code and + * making any exception that is thrown more meaningful. + *

+ * Is aware of a corresponding Connection bound to the current {@link reactor.util.context.Context}. Will bind a + * Connection to the {@link reactor.util.context.Context} if transaction synchronization is active. + * + * @param connectionFactory the {@link io.r2dbc.spi.ConnectionFactory} to obtain Connections from + * @return a R2DBC Connection from the given {@link io.r2dbc.spi.ConnectionFactory}. + * @throws DataAccessResourceFailureException if the attempt to get a {@link io.r2dbc.spi.Connection} failed + * @see #releaseConnection + */ + public static Mono> getConnection(ConnectionFactory connectionFactory) { + return doGetConnection(connectionFactory) + .onErrorMap(e -> new DataAccessResourceFailureException("Failed to obtain R2DBC Connection", e)); + } + + /** + * Actually obtain a R2DBC Connection from the given {@link ConnectionFactory}. Same as {@link #getConnection}, but + * preserving the original exceptions. + *

+ * Is aware of a corresponding Connection bound to the current {@link reactor.util.context.Context}. Will bind a + * Connection to the {@link reactor.util.context.Context} if transaction synchronization is active. + * + * @param connectionFactory the {@link ConnectionFactory} to obtain Connections from. + * @return a R2DBC {@link io.r2dbc.spi.Connection} from the given {@link ConnectionFactory}. + */ + public static Mono> doGetConnection(ConnectionFactory connectionFactory) { + + Assert.notNull(connectionFactory, "ConnectionFactory must not be null!"); + + return Mono.subscriberContext().flatMap(it -> { + + if (it.hasKey(ReactiveTransactionSynchronization.class)) { + + ReactiveTransactionSynchronization synchronization = it.get(ReactiveTransactionSynchronization.class); + + return obtainConnection(synchronization, connectionFactory); + } + return Mono.empty(); + }).switchIfEmpty(Mono.defer(() -> { + return Mono.from(connectionFactory.create()).map(it -> Tuples.of(it, connectionFactory)); + })); + } + + private static Mono> obtainConnection( + ReactiveTransactionSynchronization synchronization, ConnectionFactory connectionFactory) { + + if (synchronization.isSynchronizationActive()) { + + logger.debug("Registering transaction synchronization for R2DBC Connection"); + + TransactionResources txContext = synchronization.getCurrentTransaction(); + ConnectionFactory resource = txContext.getResource(ConnectionFactory.class); + + Mono> attachNewConnection = Mono + .defer(() -> Mono.from(connectionFactory.create()).map(it -> { + + logger.debug("Fetching new R2DBC Connection from ConnectionFactory"); + + SingletonConnectionFactory s = new SingletonConnectionFactory(connectionFactory.getMetadata(), it); + txContext.registerResource(ConnectionFactory.class, s); + + return Tuples.of(it, connectionFactory); + })); + + return Mono.justOrEmpty(resource).flatMap(factory -> { + + logger.debug("Fetching resumed R2DBC Connection from ConnectionFactory"); + + return Mono.from(factory.create()) + .map(connection -> Tuples. of(connection, factory)); + + }).switchIfEmpty(attachNewConnection); + } + + return Mono.empty(); + } + + /** + * Close the given {@link io.r2dbc.spi.Connection}, obtained from the given {@link ConnectionFactory}, if it is not + * managed externally (that is, not bound to the thread). + * + * @param con the {@link io.r2dbc.spi.Connection} to close if necessary. + * @param connectionFactory the {@link ConnectionFactory} that the Connection was obtained from (may be + * {@literal null}). + * @see #getConnection + */ + public static Mono releaseConnection(@Nullable io.r2dbc.spi.Connection con, + @Nullable ConnectionFactory connectionFactory) { + + return doReleaseConnection(con, connectionFactory) + .onErrorMap(e -> new DataAccessResourceFailureException("Failed to close R2DBC Connection", e)); + } + + /** + * Actually close the given {@link io.r2dbc.spi.Connection}, obtained from the given {@link ConnectionFactory}. Same + * as {@link #releaseConnection}, but preserving the original exception. + * + * @param con the {@link io.r2dbc.spi.Connection} to close if necessary. + * @param connectionFactory the {@link ConnectionFactory} that the Connection was obtained from (may be + * {@literal null}). + * @see #doGetConnection + */ + public static Mono doReleaseConnection(@Nullable io.r2dbc.spi.Connection con, + @Nullable ConnectionFactory connectionFactory) { + + if (connectionFactory instanceof SingletonConnectionFactory) { + + SingletonConnectionFactory factory = (SingletonConnectionFactory) connectionFactory; + + logger.debug("Releasing R2DBC Connection"); + + return factory.close(con); + } + + logger.debug("Closing R2DBC Connection"); + + return Mono.from(con.close()); + } + + /** + * Close the {@link io.r2dbc.spi.Connection}. Translates exceptions into the Spring hierarchy of unchecked generic + * data access exceptions, simplifying calling code and making any exception that is thrown more meaningful. + * + * @param connectionFactory the {@link io.r2dbc.spi.ConnectionFactory} to obtain Connections from + * @return a R2DBC Connection from the given {@link io.r2dbc.spi.ConnectionFactory}. + * @throws DataAccessResourceFailureException if the attempt to get a {@link io.r2dbc.spi.Connection} failed + */ + public static Mono closeConnection(Connection connection, ConnectionFactory connectionFactory) { + return doCloseConnection(connection, connectionFactory) + .onErrorMap(e -> new DataAccessResourceFailureException("Failed to obtain R2DBC Connection", e)); + } + + /** + * Close the {@link io.r2dbc.spi.Connection}, unless a {@link SmartConnectionFactory} doesn't want us to. + * + * @param connection the {@link io.r2dbc.spi.Connection} to close if necessary. + * @param connectionFactory the {@link ConnectionFactory} that the Connection was obtained from. + * @see Connection#close() + * @see SmartConnectionFactory#shouldClose(Connection) + */ + public static Mono doCloseConnection(Connection connection, @Nullable ConnectionFactory connectionFactory) { + + if (!(connectionFactory instanceof SingletonConnectionFactory) + || ((SingletonConnectionFactory) connectionFactory).shouldClose(connection)) { + + SingletonConnectionFactory factory = (SingletonConnectionFactory) connectionFactory; + return factory.close(connection).then(Mono.from(connection.close())); + } + + return Mono.empty(); + } + + /** + * Obtain the currently {@link ReactiveTransactionSynchronization} from the current subscriber + * {@link reactor.util.context.Context}. + * + * @see Mono#subscriberContext() + * @see ReactiveTransactionSynchronization + * @throws NoTransactionException if no active {@link ReactiveTransactionSynchronization} is associated with the + * current subscription. + */ + public static Mono currentReactiveTransactionSynchronization() { + + return Mono.subscriberContext().filter(it -> it.hasKey(ReactiveTransactionSynchronization.class)) // + .switchIfEmpty(Mono.error(new NoTransactionException( + "Transaction management is not enabled. Make sure to register ReactiveTransactionSynchronization in the subscriber Context!"))) // + .map(it -> it.get(ReactiveTransactionSynchronization.class)); + } + + /** + * Obtain the currently active {@link ReactiveTransactionSynchronization} from the current subscriber + * {@link reactor.util.context.Context}. + * + * @see Mono#subscriberContext() + * @see ReactiveTransactionSynchronization + * @throws NoTransactionException if no active {@link ReactiveTransactionSynchronization} is associated with the + * current subscription. + */ + public static Mono currentActiveReactiveTransactionSynchronization() { + + return currentReactiveTransactionSynchronization() + .filter(ReactiveTransactionSynchronization::isSynchronizationActive) // + .switchIfEmpty(Mono.error(new NoTransactionException("ReactiveTransactionSynchronization not active!"))); + } + + /** + * Obtain the {@link io.r2dbc.spi.ConnectionFactory} from the current subscriber {@link reactor.util.context.Context}. + * + * @see Mono#subscriberContext() + * @see ReactiveTransactionSynchronization + * @see TransactionResources + */ + public static Mono currentConnectionFactory() { + + return currentActiveReactiveTransactionSynchronization() // + .map(synchronization -> { + + TransactionResources currentSynchronization = synchronization.getCurrentTransaction(); + return currentSynchronization.getResource(ConnectionFactory.class); + }).switchIfEmpty(Mono.error(new DataAccessResourceFailureException( + "Cannot extract ConnectionFactory from current TransactionContext!"))); + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/DefaultTransactionResources.java b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/DefaultTransactionResources.java new file mode 100644 index 000000000..853448dc1 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/DefaultTransactionResources.java @@ -0,0 +1,51 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function.connectionfactory; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.util.Assert; + +/** + * Default implementation of {@link TransactionResources}. + * + * @author Mark Paluch + */ +class DefaultTransactionResources implements TransactionResources { + + private Map, Object> items = new ConcurrentHashMap<>(); + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.connectionfactory.TransactionResources#registerResource(java.lang.Class, java.lang.Object) + */ + @Override + public void registerResource(Class key, T value) { + + Assert.state(!items.containsKey(key), () -> String.format("Resource for %s is already bound", key)); + + items.put(key, value); + } + + /* (non-Javadoc) + * @see org.springframework.data.r2dbc.function.connectionfactory.TransactionResources#getResource(java.lang.Class) + */ + @SuppressWarnings("unchecked") + @Override + public T getResource(Class key) { + return (T) items.get(key); + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/ReactiveTransactionSynchronization.java b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/ReactiveTransactionSynchronization.java new file mode 100644 index 000000000..af2ab9344 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/ReactiveTransactionSynchronization.java @@ -0,0 +1,87 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function.connectionfactory; + +import java.util.Stack; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Central delegate that manages transactional resources. To be used by resource management code but not by typical + * application code. + *

+ * Supports a list of transactional resources if synchronization is active. + *

+ * Resource management code should check for subscriber {@link reactor.util.context.Context}-bound resources, e.g. R2DBC + * Connections using {@link TransactionResources#getResource(Class)}. Such code is normally not supposed to bind + * resources, as this is the responsibility of transaction managers. A further option is to lazily bind on first use if + * transaction synchronization is active, for performing transactions that span an arbitrary number of resources. + *

+ * Transaction synchronization must be activated and deactivated by a transaction manager by registering + * {@link ReactiveTransactionSynchronization} in the {@link reactor.util.context.Context subscriber context}. + * + * @author Mark Paluch + */ +public class ReactiveTransactionSynchronization { + + private Stack resources = new Stack<>(); + + /** + * Return if transaction synchronization is active for the current {@link reactor.util.context.Context}. Can be called + * before register to avoid unnecessary instance creation. + */ + public boolean isSynchronizationActive() { + return !resources.isEmpty(); + } + + /** + * Create a new transaction span and register a {@link TransactionResources} instance. + * + * @param transactionResources must not be {@literal null}. + */ + public void registerTransaction(TransactionResources transactionResources) { + + Assert.notNull(transactionResources, "TransactionContext must not be null!"); + + resources.push(transactionResources); + } + + /** + * Unregister a transaction span and by removing {@link TransactionResources} instance. + * + * @param transactionResources must not be {@literal null}. + */ + public void unregisterTransaction(TransactionResources transactionResources) { + + Assert.notNull(transactionResources, "TransactionContext must not be null!"); + + resources.remove(transactionResources); + } + + /** + * @return obtain the current {@link TransactionResources} or {@literal null} if none is present. + */ + @Nullable + public TransactionResources getCurrentTransaction() { + + if (!resources.isEmpty()) { + return resources.peek(); + } + + return null; + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/SingletonConnectionFactory.java b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/SingletonConnectionFactory.java new file mode 100644 index 000000000..5610bf791 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/SingletonConnectionFactory.java @@ -0,0 +1,85 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function.connectionfactory; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactoryMetadata; +import reactor.core.publisher.Mono; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.reactivestreams.Publisher; + +/** + * Connection holder, wrapping a R2DBC Connection. + * {@link org.springframework.data.r2dbc.function.TransactionalDatabaseClient} binds instances of this class to the + * {@link TransactionResources} for a specific subscription. + * + * @author Mark Paluch + */ +class SingletonConnectionFactory implements SmartConnectionFactory { + + private final ConnectionFactoryMetadata metadata; + private final Connection connection; + private final Mono connectionMono; + private final AtomicInteger refCount = new AtomicInteger(); + + SingletonConnectionFactory(ConnectionFactoryMetadata metadata, Connection connection) { + + this.metadata = metadata; + this.connection = connection; + this.connectionMono = Mono.just(connection); + } + + /* (non-Javadoc) + * @see io.r2dbc.spi.ConnectionFactory#create() + */ + @Override + public Publisher create() { + + if (refCount.get() == -1) { + throw new IllegalStateException("Connection is closed!"); + } + + return connectionMono.doOnSubscribe(s -> refCount.incrementAndGet()); + } + + /* (non-Javadoc) + * @see io.r2dbc.spi.ConnectionFactory#getMetadata() + */ + @Override + public ConnectionFactoryMetadata getMetadata() { + return metadata; + } + + private boolean connectionEquals(Connection connection) { + return this.connection == connection; + } + + @Override + public boolean shouldClose(Connection connection) { + return refCount.get() == 1; + } + + Mono close(Connection connection) { + + if (connectionEquals(connection)) { + return Mono. empty().doOnSubscribe(s -> refCount.decrementAndGet()); + } + + throw new IllegalArgumentException("Connection is not associated with this connection factory"); + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/SmartConnectionFactory.java b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/SmartConnectionFactory.java new file mode 100644 index 000000000..221cdc54f --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/SmartConnectionFactory.java @@ -0,0 +1,44 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function.connectionfactory; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; + +/** + * Extension of the {@code io.r2dbc.spi.ConnectionFactory} interface, to be implemented by special connection factories + * that return R2DBC Connections in an unwrapped fashion. + *

+ * Classes using this interface can query whether or not the {@link Connection} should be closed after an operation. + * Spring's {@link ConnectionFactoryUtils} automatically perform such a check. + * + * @author Mark Paluch + * @see ConnectionFactoryUtils#closeConnection + */ +public interface SmartConnectionFactory extends ConnectionFactory { + + /** + * Should we close this {@link io.r2dbc.spi.Connection}, obtained from this {@code io.r2dbc.spi.ConnectionFactory}? + *

+ * Code that uses Connections from a SmartConnectionFactory should always perform a check via this method before + * invoking {@code close()}. + * + * @param connection the {@link io.r2dbc.spi.Connection} to check. + * @return whether the given {@link Connection} should be closed. + * @see io.r2dbc.spi.Connection#close() + */ + boolean shouldClose(Connection connection); +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/TransactionResources.java b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/TransactionResources.java new file mode 100644 index 000000000..119c6ca87 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/connectionfactory/TransactionResources.java @@ -0,0 +1,58 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function.connectionfactory; + +import reactor.core.publisher.Mono; + +/** + * Transaction context for an ongoing transaction synchronization allowing to register transactional resources. + *

+ * Supports one resource per key without overwriting, that is, a resource needs to be removed before a new one can be + * set for the same key. + *

+ * Primarily used by {@link ConnectionFactoryUtils} but can be also used by application code to register resources that + * should be bound to a transaction. + * + * @author Mark Paluch + */ +public interface TransactionResources { + + /** + * Creates a new empty {@link TransactionResources}. + * + * @return the empty {@link TransactionResources}. + */ + static TransactionResources create() { + return new DefaultTransactionResources(); + } + + /** + * Retrieve a resource from this context identified by {@code key}. + * + * @param key the resource key. + * @return the resource emitted through {@link Mono} or {@link Mono#empty()} if the resource was not found. + */ + T getResource(Class key); + + /** + * Register a resource in this context. + * + * @param key the resource key. + * @param value can be a subclass of the {@code key} type. + * @throws IllegalStateException if a resource is already bound under {@code key}. + */ + void registerResource(Class key, T value); +} diff --git a/src/test/java/org/springframework/data/r2dbc/function/TransactionalDatabaseClientIntegrationTests.java b/src/test/java/org/springframework/data/r2dbc/function/TransactionalDatabaseClientIntegrationTests.java new file mode 100644 index 000000000..ade4922cc --- /dev/null +++ b/src/test/java/org/springframework/data/r2dbc/function/TransactionalDatabaseClientIntegrationTests.java @@ -0,0 +1,181 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function; + +import static org.assertj.core.api.Assertions.*; + +import io.r2dbc.spi.ConnectionFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.ArrayBlockingQueue; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.data.jdbc.testing.R2dbcIntegrationTestSupport; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.transaction.NoTransactionException; + +/** + * Integration tests for {@link TransactionalDatabaseClient}. + * + * @author Mark Paluch + */ +public class TransactionalDatabaseClientIntegrationTests extends R2dbcIntegrationTestSupport { + + private ConnectionFactory connectionFactory; + + private JdbcTemplate jdbc; + + @Before + public void before() { + + Hooks.onOperatorDebug(); + + connectionFactory = createConnectionFactory(); + + String tableToCreate = "CREATE TABLE IF NOT EXISTS legoset (\n" + + " id integer CONSTRAINT id PRIMARY KEY,\n" + " name varchar(255) NOT NULL,\n" + + " manual integer NULL\n" + ");"; + + jdbc = createJdbcTemplate(createDataSource()); + jdbc.execute(tableToCreate); + jdbc.execute("DELETE FROM legoset"); + } + + @Test + public void executeInsertInManagedTransaction() { + + TransactionalDatabaseClient databaseClient = TransactionalDatabaseClient.create(connectionFactory); + + Flux integerFlux = databaseClient.inTransaction(db -> { + + return db.execute().sql("INSERT INTO legoset (id, name, manual) VALUES($1, $2, $3)") // + .bind(0, 42055) // + .bind(1, "SCHAUFELRADBAGGER") // + .bindNull("$3") // + .fetch().rowsUpdated(); + }); + + integerFlux.as(StepVerifier::create) // + .expectNext(1) // + .verifyComplete(); + + assertThat(jdbc.queryForMap("SELECT id, name, manual FROM legoset")).containsEntry("id", 42055); + } + + @Test + public void executeInsertInAutoCommitTransaction() { + + TransactionalDatabaseClient databaseClient = TransactionalDatabaseClient.create(connectionFactory); + + Mono integerFlux = databaseClient.execute() + .sql("INSERT INTO legoset (id, name, manual) VALUES($1, $2, $3)") // + .bind(0, 42055) // + .bind(1, "SCHAUFELRADBAGGER") // + .bindNull("$3") // + .fetch().rowsUpdated(); + + integerFlux.as(StepVerifier::create) // + .expectNext(1) // + .verifyComplete(); + + assertThat(jdbc.queryForMap("SELECT id, name, manual FROM legoset")).containsEntry("id", 42055); + } + + @Test + public void shouldManageUserTransaction() { + + Queue transactionIds = new ArrayBlockingQueue<>(5); + TransactionalDatabaseClient databaseClient = TransactionalDatabaseClient.create(connectionFactory); + + Flux txId = databaseClient.execute().sql("SELECT txid_current();").exchange() + .flatMapMany(it -> it.extract((r, md) -> r.get(0, Long.class)).all()); + + Mono then = databaseClient.enableTransactionSynchronization(databaseClient.beginTransaction() // + .thenMany(txId.concatWith(txId).doOnNext(transactionIds::add)) // + .then(databaseClient.rollbackTransaction())); + + then.as(StepVerifier::create) // + .verifyComplete(); + + List listOfTxIds = new ArrayList<>(transactionIds); + assertThat(listOfTxIds).hasSize(2); + assertThat(listOfTxIds).containsExactly(listOfTxIds.get(1), listOfTxIds.get(0)); + } + + @Test + public void userTransactionManagementShouldFailWithoutSynchronizer() { + + TransactionalDatabaseClient databaseClient = TransactionalDatabaseClient.create(connectionFactory); + + Mono then = databaseClient.beginTransaction().then(databaseClient.rollbackTransaction()); + + then.as(StepVerifier::create) // + .consumeErrorWith(exception -> { + + assertThat(exception).isInstanceOf(NoTransactionException.class) + .hasMessageContaining("Transaction management is not enabled"); + }).verify(); + } + + @Test + public void shouldRollbackTransaction() { + + TransactionalDatabaseClient databaseClient = TransactionalDatabaseClient.create(connectionFactory); + + Flux integerFlux = databaseClient.inTransaction(db -> { + + return db.execute().sql("INSERT INTO legoset (id, name, manual) VALUES($1, $2, $3)") // + .bind(0, 42055) // + .bind(1, "SCHAUFELRADBAGGER") // + .bindNull("$3") // + .fetch().rowsUpdated().then(Mono.error(new IllegalStateException("failed"))); + }); + + integerFlux.as(StepVerifier::create) // + .expectError(IllegalStateException.class) // + .verify(); + + assertThat(jdbc.queryForMap("SELECT count(*) FROM legoset")).containsEntry("count", 0L); + } + + @Test + public void emitTransactionIds() { + + TransactionalDatabaseClient databaseClient = TransactionalDatabaseClient.create(connectionFactory); + + Flux transactionIds = databaseClient.inTransaction(db -> { + + Flux txId = db.execute().sql("SELECT txid_current();").exchange() + .flatMapMany(it -> it.extract((r, md) -> r.get(0, Long.class)).all()); + return txId.concatWith(txId); + }); + + transactionIds.collectList().as(StepVerifier::create) // + .consumeNextWith(actual -> { + + assertThat(actual).hasSize(2); + assertThat(actual).containsExactly(actual.get(1), actual.get(0)); + }) // + .verifyComplete(); + } +} diff --git a/src/test/java/org/springframework/data/r2dbc/function/connectionfactory/ConnectionFactoryUtilsUnitTests.java b/src/test/java/org/springframework/data/r2dbc/function/connectionfactory/ConnectionFactoryUtilsUnitTests.java new file mode 100644 index 000000000..756340587 --- /dev/null +++ b/src/test/java/org/springframework/data/r2dbc/function/connectionfactory/ConnectionFactoryUtilsUnitTests.java @@ -0,0 +1,93 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function.connectionfactory; + +import static org.mockito.Mockito.*; + +import io.r2dbc.spi.ConnectionFactory; +import reactor.test.StepVerifier; + +import org.junit.Test; +import org.springframework.transaction.NoTransactionException; + +/** + * Unit tests for {@link ConnectionFactoryUtils}. + * + * @author Mark Paluch + */ +public class ConnectionFactoryUtilsUnitTests { + + @Test + public void currentReactiveTransactionSynchronizationShouldReportSynchronization() { + + ConnectionFactoryUtils.currentReactiveTransactionSynchronization() // + .subscriberContext( + it -> it.put(ReactiveTransactionSynchronization.class, new ReactiveTransactionSynchronization())) + .as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); + } + + @Test + public void currentReactiveTransactionSynchronizationShouldFailWithoutTxMgmt() { + + ConnectionFactoryUtils.currentReactiveTransactionSynchronization() // + .as(StepVerifier::create) // + .expectError(NoTransactionException.class) // + .verify(); + } + + @Test + public void currentActiveReactiveTransactionSynchronizationShouldReportSynchronization() { + + ConnectionFactoryUtils.currentActiveReactiveTransactionSynchronization() // + .subscriberContext(it -> { + ReactiveTransactionSynchronization sync = new ReactiveTransactionSynchronization(); + sync.registerTransaction(TransactionResources.create()); + return it.put(ReactiveTransactionSynchronization.class, sync); + }).as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); + } + + @Test + public void currentActiveReactiveTransactionSynchronization() { + + ConnectionFactoryUtils.currentActiveReactiveTransactionSynchronization() // + .subscriberContext( + it -> it.put(ReactiveTransactionSynchronization.class, new ReactiveTransactionSynchronization())) + .as(StepVerifier::create) // + .expectError(NoTransactionException.class) // + .verify(); + } + + @Test + public void currentConnectionFactoryShouldReportConnectionFactory() { + + ConnectionFactory factoryMock = mock(ConnectionFactory.class); + + ConnectionFactoryUtils.currentConnectionFactory() // + .subscriberContext(it -> { + ReactiveTransactionSynchronization sync = new ReactiveTransactionSynchronization(); + TransactionResources resources = TransactionResources.create(); + resources.registerResource(ConnectionFactory.class, factoryMock); + sync.registerTransaction(resources); + return it.put(ReactiveTransactionSynchronization.class, sync); + }).as(StepVerifier::create) // + .expectNext(factoryMock) // + .verifyComplete(); + } +} diff --git a/src/test/java/org/springframework/data/r2dbc/repository/R2dbcRepositoryIntegrationTests.java b/src/test/java/org/springframework/data/r2dbc/repository/R2dbcRepositoryIntegrationTests.java index 5a6679cdc..cc22129bb 100644 --- a/src/test/java/org/springframework/data/r2dbc/repository/R2dbcRepositoryIntegrationTests.java +++ b/src/test/java/org/springframework/data/r2dbc/repository/R2dbcRepositoryIntegrationTests.java @@ -27,6 +27,8 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import java.util.Arrays; +import java.util.Collections; +import java.util.Map; import org.junit.Before; import org.junit.Test; @@ -36,6 +38,7 @@ import org.springframework.data.jdbc.repository.query.Query; import org.springframework.data.jdbc.testing.R2dbcIntegrationTestSupport; import org.springframework.data.r2dbc.function.DatabaseClient; import org.springframework.data.r2dbc.function.DefaultReactiveDataAccessStrategy; +import org.springframework.data.r2dbc.function.TransactionalDatabaseClient; import org.springframework.data.r2dbc.repository.support.R2dbcRepositoryFactory; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.Table; @@ -129,6 +132,34 @@ public class R2dbcRepositoryIntegrationTests extends R2dbcIntegrationTestSupport }).verifyComplete(); } + @Test + public void shouldInsertItemsTransactional() { + + TransactionalDatabaseClient client = TransactionalDatabaseClient.builder().connectionFactory(connectionFactory) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(mappingContext, new EntityInstantiators())).build(); + + LegoSetRepository transactionalRepository = new R2dbcRepositoryFactory(client, mappingContext) + .getRepository(LegoSetRepository.class); + + LegoSet legoSet1 = new LegoSet(null, "SCHAUFELRADBAGGER", 12); + LegoSet legoSet2 = new LegoSet(null, "FORSCHUNGSSCHIFF", 13); + + Flux> transactional = client.inTransaction(db -> { + + return transactionalRepository.save(legoSet1) // + .map(it -> jdbc.queryForMap("SELECT count(*) FROM repo_legoset")); + }); + + Mono> nonTransactional = transactionalRepository.save(legoSet2) // + .map(it -> jdbc.queryForMap("SELECT count(*) FROM repo_legoset")); + + transactional.as(StepVerifier::create).expectNext(Collections.singletonMap("count", 0L)).verifyComplete(); + nonTransactional.as(StepVerifier::create).expectNext(Collections.singletonMap("count", 2L)).verifyComplete(); + + Map count = jdbc.queryForMap("SELECT count(*) FROM repo_legoset"); + assertThat(count).containsEntry("count", 2L); + } + interface LegoSetRepository extends ReactiveCrudRepository { @Query("SELECT * FROM repo_legoset WHERE name like $1")