From eb1200415dee2b8e84bec3b9bedc55101d7653b7 Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Wed, 24 Mar 2021 20:20:59 +0000 Subject: [PATCH] Derive a ConnectionFactoryBuilder from an existing ConnectionFactory Closes gh-25788 --- .../r2dbc/R2dbcAutoConfigurationTests.java | 82 +++++---- .../SimpleBindMarkerFactoryProvider.java | 7 +- .../SimpleConnectionFactoryProvider.java | 15 +- spring-boot-project/spring-boot/build.gradle | 1 + .../boot/r2dbc/ConnectionFactoryBuilder.java | 152 ++++++++++++++++- .../OptionsCapableConnectionFactory.java | 70 ++++++++ .../r2dbc/ConnectionFactoryBuilderTests.java | 156 ++++++++++++++++++ 7 files changed, 438 insertions(+), 45 deletions(-) create mode 100644 spring-boot-project/spring-boot/src/main/java/org/springframework/boot/r2dbc/OptionsCapableConnectionFactory.java diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/R2dbcAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/R2dbcAutoConfigurationTests.java index 499c02792e0..f959a8e3aa0 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/R2dbcAutoConfigurationTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/R2dbcAutoConfigurationTests.java @@ -29,6 +29,10 @@ import io.r2dbc.pool.ConnectionPool; import io.r2dbc.pool.PoolMetrics; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.Option; +import io.r2dbc.spi.Wrapped; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.InstanceOfAssertFactory; +import org.assertj.core.api.ObjectAssert; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.BeanCreationException; @@ -36,6 +40,7 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.autoconfigure.r2dbc.SimpleConnectionFactoryProvider.SimpleTestConnectionFactory; import org.springframework.boot.r2dbc.EmbeddedDatabaseConnection; +import org.springframework.boot.r2dbc.OptionsCapableConnectionFactory; import org.springframework.boot.test.context.FilteredClassLoader; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -58,8 +63,14 @@ class R2dbcAutoConfigurationTests { @Test void configureWithUrlCreateConnectionPoolByDefault() { this.contextRunner.withPropertyValues("spring.r2dbc.url:r2dbc:h2:mem:///" + randomDatabaseName()) - .run((context) -> assertThat(context).hasSingleBean(ConnectionFactory.class) - .hasSingleBean(ConnectionPool.class)); + .run((context) -> { + assertThat(context).hasSingleBean(ConnectionFactory.class).hasSingleBean(ConnectionPool.class); + assertThat(context.getBean(ConnectionPool.class)).extracting(ConnectionPool::unwrap) + .satisfies((connectionFactory) -> assertThat(connectionFactory) + .asInstanceOf(type(OptionsCapableConnectionFactory.class)) + .extracting(Wrapped::unwrap) + .isExactlyInstanceOf(H2ConnectionFactory.class)); + }); } @Test @@ -113,7 +124,10 @@ class R2dbcAutoConfigurationTests { this.contextRunner.withPropertyValues("spring.r2dbc.pool.enabled=false", "spring.r2dbc.url:r2dbc:h2:mem:///" + randomDatabaseName() + "?options=DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE").run((context) -> { assertThat(context).hasSingleBean(ConnectionFactory.class).doesNotHaveBean(ConnectionPool.class); - assertThat(context.getBean(ConnectionFactory.class)).isExactlyInstanceOf(H2ConnectionFactory.class); + assertThat(context.getBean(ConnectionFactory.class)) + .asInstanceOf(type(OptionsCapableConnectionFactory.class)) + .extracting(Wrapped::unwrap) + .isExactlyInstanceOf(H2ConnectionFactory.class); }); } @@ -122,8 +136,10 @@ class R2dbcAutoConfigurationTests { this.contextRunner.with(hideConnectionPool()).withPropertyValues("spring.r2dbc.url:r2dbc:h2:mem:///" + randomDatabaseName() + "?options=DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE").run((context) -> { assertThat(context).hasSingleBean(ConnectionFactory.class); - ConnectionFactory bean = context.getBean(ConnectionFactory.class); - assertThat(bean).isExactlyInstanceOf(H2ConnectionFactory.class); + assertThat(context.getBean(ConnectionFactory.class)) + .asInstanceOf(type(OptionsCapableConnectionFactory.class)) + .extracting(Wrapped::unwrap) + .isExactlyInstanceOf(H2ConnectionFactory.class); }); } @@ -142,11 +158,10 @@ class R2dbcAutoConfigurationTests { .withPropertyValues("spring.r2dbc.pool.enabled=false", "spring.r2dbc.url:r2dbc:simple://host/database") .withUserConfiguration(CustomizerConfiguration.class).run((context) -> { assertThat(context).hasSingleBean(ConnectionFactory.class).doesNotHaveBean(ConnectionPool.class); - ConnectionFactory bean = context.getBean(ConnectionFactory.class); - assertThat(bean).isExactlyInstanceOf(SimpleTestConnectionFactory.class); - SimpleTestConnectionFactory connectionFactory = (SimpleTestConnectionFactory) bean; - assertThat(connectionFactory.getOptions().getRequiredValue(Option.valueOf("customized"))) - .isTrue(); + ConnectionFactory connectionFactory = context.getBean(ConnectionFactory.class); + assertThat(connectionFactory).asInstanceOf(type(OptionsCapableConnectionFactory.class)) + .extracting(OptionsCapableConnectionFactory::getOptions).satisfies((options) -> assertThat( + options.getRequiredValue(Option.valueOf("customized"))).isTrue()); }); } @@ -155,11 +170,11 @@ class R2dbcAutoConfigurationTests { this.contextRunner.withPropertyValues("spring.r2dbc.url:r2dbc:simple://host/database") .withUserConfiguration(CustomizerConfiguration.class).run((context) -> { assertThat(context).hasSingleBean(ConnectionFactory.class).hasSingleBean(ConnectionPool.class); - ConnectionFactory bean = context.getBean(ConnectionFactory.class); - SimpleTestConnectionFactory connectionFactory = (SimpleTestConnectionFactory) ((ConnectionPool) bean) - .unwrap(); - assertThat(connectionFactory.getOptions().getRequiredValue(Option.valueOf("customized"))) - .isTrue(); + ConnectionFactory pool = context.getBean(ConnectionFactory.class); + ConnectionFactory connectionFactory = ((ConnectionPool) pool).unwrap(); + assertThat(connectionFactory).asInstanceOf(type(OptionsCapableConnectionFactory.class)) + .extracting(OptionsCapableConnectionFactory::getOptions).satisfies((options) -> assertThat( + options.getRequiredValue(Option.valueOf("customized"))).isTrue()); }); } @@ -174,8 +189,10 @@ class R2dbcAutoConfigurationTests { this.contextRunner.withPropertyValues("spring.r2dbc.pool.enabled=false", "spring.r2dbc.url:r2dbc:simple://foo") .withClassLoader(new FilteredClassLoader("org.springframework.jdbc")).run((context) -> { assertThat(context).hasSingleBean(ConnectionFactory.class); - ConnectionFactory connectionFactory = context.getBean(ConnectionFactory.class); - assertThat(connectionFactory).isInstanceOf(SimpleTestConnectionFactory.class); + assertThat(context.getBean(ConnectionFactory.class)) + .asInstanceOf(type(OptionsCapableConnectionFactory.class)) + .extracting(Wrapped::unwrap) + .isExactlyInstanceOf(SimpleTestConnectionFactory.class); }); } @@ -183,9 +200,12 @@ class R2dbcAutoConfigurationTests { void configureWithoutPoolShouldApplyAdditionalProperties() { this.contextRunner.withPropertyValues("spring.r2dbc.pool.enabled=false", "spring.r2dbc.url:r2dbc:simple://foo", "spring.r2dbc.properties.test=value", "spring.r2dbc.properties.another=2").run((context) -> { - SimpleTestConnectionFactory connectionFactory = context.getBean(SimpleTestConnectionFactory.class); - assertThat(getRequiredOptionsValue(connectionFactory, "test")).isEqualTo("value"); - assertThat(getRequiredOptionsValue(connectionFactory, "another")).isEqualTo("2"); + ConnectionFactory connectionFactory = context.getBean(ConnectionFactory.class); + assertThat(connectionFactory).asInstanceOf(type(OptionsCapableConnectionFactory.class)) + .extracting(OptionsCapableConnectionFactory::getOptions).satisfies((options) -> { + assertThat(options.getRequiredValue(Option.valueOf("test"))).isEqualTo("value"); + assertThat(options.getRequiredValue(Option.valueOf("another"))).isEqualTo("2"); + }); }); } @@ -194,17 +214,15 @@ class R2dbcAutoConfigurationTests { this.contextRunner.withPropertyValues("spring.r2dbc.url:r2dbc:simple://foo", "spring.r2dbc.properties.test=value", "spring.r2dbc.properties.another=2").run((context) -> { assertThat(context).hasSingleBean(ConnectionFactory.class).hasSingleBean(ConnectionPool.class); - SimpleTestConnectionFactory connectionFactory = (SimpleTestConnectionFactory) context - .getBean(ConnectionPool.class).unwrap(); - assertThat(getRequiredOptionsValue(connectionFactory, "test")).isEqualTo("value"); - assertThat(getRequiredOptionsValue(connectionFactory, "another")).isEqualTo("2"); + ConnectionFactory connectionFactory = context.getBean(ConnectionPool.class).unwrap(); + assertThat(connectionFactory).asInstanceOf(type(OptionsCapableConnectionFactory.class)) + .extracting(OptionsCapableConnectionFactory::getOptions).satisfies((options) -> { + assertThat(options.getRequiredValue(Option.valueOf("test"))).isEqualTo("value"); + assertThat(options.getRequiredValue(Option.valueOf("another"))).isEqualTo("2"); + }); }); } - private Object getRequiredOptionsValue(SimpleTestConnectionFactory connectionFactory, String name) { - return connectionFactory.options.getRequiredValue(Option.valueOf(name)); - } - @Test void configureWithoutUrlShouldCreateEmbeddedConnectionPoolByDefault() { this.contextRunner.run((context) -> assertThat(context).hasSingleBean(ConnectionFactory.class) @@ -215,7 +233,9 @@ class R2dbcAutoConfigurationTests { void configureWithoutUrlAndPollPoolDisabledCreateGenericConnectionFactory() { this.contextRunner.withPropertyValues("spring.r2dbc.pool.enabled=false").run((context) -> { assertThat(context).hasSingleBean(ConnectionFactory.class).doesNotHaveBean(ConnectionPool.class); - assertThat(context.getBean(ConnectionFactory.class)).isExactlyInstanceOf(H2ConnectionFactory.class); + assertThat(context.getBean(ConnectionFactory.class)) + .asInstanceOf(type(OptionsCapableConnectionFactory.class)) + .extracting(Wrapped::unwrap).isExactlyInstanceOf(H2ConnectionFactory.class); }); } @@ -260,6 +280,10 @@ class R2dbcAutoConfigurationTests { .doesNotHaveBean(DatabaseClient.class)); } + private InstanceOfAssertFactory> type(Class type) { + return InstanceOfAssertFactories.type(type); + } + private String randomDatabaseName() { return "testdb-" + UUID.randomUUID(); } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/SimpleBindMarkerFactoryProvider.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/SimpleBindMarkerFactoryProvider.java index 8fbd3a47dbd..79d55a6aa4c 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/SimpleBindMarkerFactoryProvider.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/SimpleBindMarkerFactoryProvider.java @@ -16,8 +16,8 @@ package org.springframework.boot.autoconfigure.r2dbc; -import io.r2dbc.pool.ConnectionPool; import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Wrapped; import org.springframework.boot.autoconfigure.r2dbc.SimpleConnectionFactoryProvider.SimpleTestConnectionFactory; import org.springframework.r2dbc.core.binding.BindMarkersFactory; @@ -38,9 +38,10 @@ public class SimpleBindMarkerFactoryProvider implements BindMarkerFactoryProvide return null; } + @SuppressWarnings("unchecked") private ConnectionFactory unwrapIfNecessary(ConnectionFactory connectionFactory) { - if (connectionFactory instanceof ConnectionPool) { - return ((ConnectionPool) connectionFactory).unwrap(); + if (connectionFactory instanceof Wrapped) { + return unwrapIfNecessary(((Wrapped) connectionFactory).unwrap()); } return connectionFactory; } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/SimpleConnectionFactoryProvider.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/SimpleConnectionFactoryProvider.java index 2a589fbf795..dd8c40f1e60 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/SimpleConnectionFactoryProvider.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/r2dbc/SimpleConnectionFactoryProvider.java @@ -25,15 +25,16 @@ import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; /** - * Simple driver to capture {@link ConnectionFactoryOptions}. + * Simple driver for testing. * * @author Mark Paluch + * @author Andy Wilkinson */ public class SimpleConnectionFactoryProvider implements ConnectionFactoryProvider { @Override public ConnectionFactory create(ConnectionFactoryOptions connectionFactoryOptions) { - return new SimpleTestConnectionFactory(connectionFactoryOptions); + return new SimpleTestConnectionFactory(); } @Override @@ -48,12 +49,6 @@ public class SimpleConnectionFactoryProvider implements ConnectionFactoryProvide public static class SimpleTestConnectionFactory implements ConnectionFactory { - final ConnectionFactoryOptions options; - - SimpleTestConnectionFactory(ConnectionFactoryOptions options) { - this.options = options; - } - @Override public Publisher create() { return Mono.error(new UnsupportedOperationException()); @@ -64,10 +59,6 @@ public class SimpleConnectionFactoryProvider implements ConnectionFactoryProvide return SimpleConnectionFactoryProvider.class::getName; } - public ConnectionFactoryOptions getOptions() { - return this.options; - } - } } diff --git a/spring-boot-project/spring-boot/build.gradle b/spring-boot-project/spring-boot/build.gradle index 8421509c081..a14806d377e 100644 --- a/spring-boot-project/spring-boot/build.gradle +++ b/spring-boot-project/spring-boot/build.gradle @@ -35,6 +35,7 @@ dependencies { optional("io.netty:netty-tcnative-boringssl-static") optional("io.projectreactor:reactor-tools") optional("io.projectreactor.netty:reactor-netty-http") + optional("io.r2dbc:r2dbc-pool") optional("io.rsocket:rsocket-core") optional("io.rsocket:rsocket-transport-netty") optional("io.undertow:undertow-servlet") { diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/r2dbc/ConnectionFactoryBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/r2dbc/ConnectionFactoryBuilder.java index 40e66a7af17..125f0a3eb85 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/r2dbc/ConnectionFactoryBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/r2dbc/ConnectionFactoryBuilder.java @@ -16,14 +16,24 @@ package org.springframework.boot.r2dbc; +import java.time.Duration; +import java.util.Locale; import java.util.function.Consumer; +import java.util.function.Function; +import io.r2dbc.pool.ConnectionPool; +import io.r2dbc.pool.ConnectionPoolConfiguration; +import io.r2dbc.pool.PoolingConnectionFactoryProvider; import io.r2dbc.spi.ConnectionFactories; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.ConnectionFactoryOptions; import io.r2dbc.spi.ConnectionFactoryOptions.Builder; +import io.r2dbc.spi.ValidationDepth; +import io.r2dbc.spi.Wrapped; +import org.springframework.boot.context.properties.PropertyMapper; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; /** * Builder for {@link ConnectionFactory}. @@ -31,10 +41,24 @@ import org.springframework.util.Assert; * @author Mark Paluch * @author Tadaya Tsuyukubo * @author Stephane Nicoll + * @author Andy Wilkinson * @since 2.5.0 */ public final class ConnectionFactoryBuilder { + private static final OptionsCapableWrapper optionsCapableWrapper; + + static { + if (ClassUtils.isPresent("io.r2dbc.pool.ConnectionPool", ConnectionFactoryBuilder.class.getClassLoader())) { + optionsCapableWrapper = new PoolingAwareOptionsCapableWrapper(); + } + else { + optionsCapableWrapper = new OptionsCapableWrapper(); + } + } + + private static final String COLON = ":"; + private final Builder optionsBuilder; private ConnectionFactoryBuilder(Builder optionsBuilder) { @@ -63,6 +87,35 @@ public final class ConnectionFactoryBuilder { return new ConnectionFactoryBuilder(options); } + /** + * Initialize a new {@link ConnectionFactoryBuilder} derived from the options of the + * specified {@code connectionFactory}. + * @param connectionFactory the connection factory whose options are to be used to + * initialize the builder + * @return a new builder initialized with the options from the connection factory + */ + public static ConnectionFactoryBuilder derivefrom(ConnectionFactory connectionFactory) { + ConnectionFactoryOptions options = extractOptionsIfPossible(connectionFactory); + if (options == null) { + throw new IllegalArgumentException( + "ConnectionFactoryOptions could not be extracted from " + connectionFactory); + } + return withOptions(options.mutate()); + } + + private static ConnectionFactoryOptions extractOptionsIfPossible(ConnectionFactory connectionFactory) { + if (connectionFactory instanceof OptionsCapableConnectionFactory) { + return ((OptionsCapableConnectionFactory) connectionFactory).getOptions(); + } + if (connectionFactory instanceof Wrapped) { + Object unwrapped = ((Wrapped) connectionFactory).unwrap(); + if (unwrapped instanceof ConnectionFactory) { + return extractOptionsIfPossible((ConnectionFactory) unwrapped); + } + } + return null; + } + /** * Configure additional options. * @param options a {@link Consumer} to customize the options @@ -123,7 +176,8 @@ public final class ConnectionFactoryBuilder { * @return a connection factory */ public ConnectionFactory build() { - return ConnectionFactories.get(buildOptions()); + ConnectionFactoryOptions options = buildOptions(); + return optionsCapableWrapper.buildAndWrap(options); } /** @@ -134,4 +188,100 @@ public final class ConnectionFactoryBuilder { return this.optionsBuilder.build(); } + private static class OptionsCapableWrapper { + + ConnectionFactory buildAndWrap(ConnectionFactoryOptions options) { + ConnectionFactory connectionFactory = ConnectionFactories.get(options); + return new OptionsCapableConnectionFactory(options, connectionFactory); + } + + } + + static final class PoolingAwareOptionsCapableWrapper extends OptionsCapableWrapper { + + private final PoolingConnectionFactoryProvider poolingProvider = new PoolingConnectionFactoryProvider(); + + @Override + ConnectionFactory buildAndWrap(ConnectionFactoryOptions options) { + if (!this.poolingProvider.supports(options)) { + return super.buildAndWrap(options); + } + ConnectionFactoryOptions delegateOptions = delegateFactoryOptions(options); + ConnectionFactory connectionFactory = super.buildAndWrap(delegateOptions); + ConnectionPoolConfiguration poolConfiguration = connectionPoolConfiguration(delegateOptions, + connectionFactory); + return new ConnectionPool(poolConfiguration); + } + + private ConnectionFactoryOptions delegateFactoryOptions(ConnectionFactoryOptions options) { + String protocol = options.getRequiredValue(ConnectionFactoryOptions.PROTOCOL); + if (protocol.trim().length() == 0) { + throw new IllegalArgumentException(String.format("Protocol %s is not valid.", protocol)); + } + String[] protocols = protocol.split(COLON, 2); + String driverDelegate = protocols[0]; + String protocolDelegate = (protocols.length != 2) ? "" : protocols[1]; + ConnectionFactoryOptions newOptions = ConnectionFactoryOptions.builder().from(options) + .option(ConnectionFactoryOptions.DRIVER, driverDelegate) + .option(ConnectionFactoryOptions.PROTOCOL, protocolDelegate).build(); + return newOptions; + } + + ConnectionPoolConfiguration connectionPoolConfiguration(ConnectionFactoryOptions options, + ConnectionFactory connectionFactory) { + ConnectionPoolConfiguration.Builder builder = ConnectionPoolConfiguration.builder(connectionFactory); + PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); + map.from((Object) options.getValue(PoolingConnectionFactoryProvider.INITIAL_SIZE)).as(this::toInteger) + .to(builder::initialSize); + map.from((Object) options.getValue(PoolingConnectionFactoryProvider.MAX_SIZE)).as(this::toInteger) + .to(builder::maxSize); + map.from((Object) options.getValue(PoolingConnectionFactoryProvider.ACQUIRE_RETRY)).as(this::toInteger) + .to(builder::acquireRetry); + map.from((Object) options.getValue(PoolingConnectionFactoryProvider.MAX_LIFE_TIME)).as(this::toDuration) + .to(builder::maxLifeTime); + map.from((Object) options.getValue(PoolingConnectionFactoryProvider.MAX_ACQUIRE_TIME)).as(this::toDuration) + .to(builder::maxAcquireTime); + map.from((Object) options.getValue(PoolingConnectionFactoryProvider.MAX_IDLE_TIME)).as(this::toDuration) + .to(builder::maxIdleTime); + map.from((Object) options.getValue(PoolingConnectionFactoryProvider.MAX_CREATE_CONNECTION_TIME)) + .as(this::toDuration).to(builder::maxCreateConnectionTime); + map.from(options.getValue(PoolingConnectionFactoryProvider.POOL_NAME)).to(builder::name); + map.from((Object) options.getValue(PoolingConnectionFactoryProvider.REGISTER_JMX)).as(this::toBoolean) + .to(builder::registerJmx); + map.from(options.getValue(PoolingConnectionFactoryProvider.VALIDATION_QUERY)).to(builder::validationQuery); + map.from((Object) options.getValue(PoolingConnectionFactoryProvider.VALIDATION_DEPTH)) + .as(this::toValidationDepth).to(builder::validationDepth); + ConnectionPoolConfiguration build = builder.build(); + return build; + } + + private Integer toInteger(Object object) { + return toType(Integer.class, object, Integer::valueOf); + } + + private Duration toDuration(Object object) { + return toType(Duration.class, object, Duration::parse); + } + + private Boolean toBoolean(Object object) { + return toType(Boolean.class, object, Boolean::valueOf); + } + + private ValidationDepth toValidationDepth(Object object) { + return toType(ValidationDepth.class, object, + (string) -> ValidationDepth.valueOf(string.toUpperCase(Locale.ENGLISH))); + } + + private T toType(Class type, Object object, Function converter) { + if (type.isInstance(object)) { + return type.cast(object); + } + if (object instanceof String) { + return converter.apply((String) object); + } + throw new IllegalArgumentException("Cannot convert '" + object + "' to " + type.getName()); + } + + } + } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/r2dbc/OptionsCapableConnectionFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/r2dbc/OptionsCapableConnectionFactory.java new file mode 100644 index 00000000000..38f55be48d3 --- /dev/null +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/r2dbc/OptionsCapableConnectionFactory.java @@ -0,0 +1,70 @@ +/* + * Copyright 2012-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.r2dbc; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.ConnectionFactoryMetadata; +import io.r2dbc.spi.ConnectionFactoryOptions; +import io.r2dbc.spi.Wrapped; +import org.reactivestreams.Publisher; + +/** + * {@link ConnectionFactory} capable of providing access to the + * {@link ConnectionFactoryOptions} from which it was built. + * + * @author Andy Wilkinson + * @since 2.5.0 + */ +public class OptionsCapableConnectionFactory implements Wrapped, ConnectionFactory { + + private final ConnectionFactoryOptions options; + + private final ConnectionFactory delegate; + + /** + * Create a new {@code OptionsCapableConnectionFactory} that will provide access to + * the given {@code options} that were used to build the given {@code delegate} + * {@link ConnectionFactory}. + * @param options the options from which the connection factory was built + * @param delegate the delegate connection factory that was built with options + */ + public OptionsCapableConnectionFactory(ConnectionFactoryOptions options, ConnectionFactory delegate) { + this.options = options; + this.delegate = delegate; + } + + public ConnectionFactoryOptions getOptions() { + return this.options; + } + + @Override + public Publisher create() { + return this.delegate.create(); + } + + @Override + public ConnectionFactoryMetadata getMetadata() { + return this.delegate.getMetadata(); + } + + @Override + public ConnectionFactory unwrap() { + return this.delegate; + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/r2dbc/ConnectionFactoryBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/r2dbc/ConnectionFactoryBuilderTests.java index 2724f1a8b0d..06abd175a1e 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/r2dbc/ConnectionFactoryBuilderTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/r2dbc/ConnectionFactoryBuilderTests.java @@ -16,16 +16,30 @@ package org.springframework.boot.r2dbc; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; import java.util.UUID; import io.r2dbc.h2.H2ConnectionFactoryMetadata; +import io.r2dbc.pool.ConnectionPool; +import io.r2dbc.pool.ConnectionPoolConfiguration; +import io.r2dbc.pool.PoolingConnectionFactoryProvider; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.ConnectionFactoryOptions; import io.r2dbc.spi.Option; +import io.r2dbc.spi.ValidationDepth; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.boot.r2dbc.ConnectionFactoryBuilder.PoolingAwareOptionsCapableWrapper; +import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.Mockito.mock; /** * Tests for {@link ConnectionFactoryBuilder}. @@ -120,4 +134,146 @@ class ConnectionFactoryBuilderTests { assertThat(connectionFactory.getMetadata().getName()).isEqualTo(H2ConnectionFactoryMetadata.NAME); } + @Test + void buildWhenDerivedWithNewDatabaseReturnsNewConnectionFactory() { + String intialDatabaseName = UUID.randomUUID().toString(); + ConnectionFactory connectionFactory = ConnectionFactoryBuilder + .withUrl(EmbeddedDatabaseConnection.H2.getUrl(intialDatabaseName)).build(); + ConnectionFactoryOptions initialOptions = ((OptionsCapableConnectionFactory) connectionFactory).getOptions(); + String derivedDatabaseName = UUID.randomUUID().toString(); + ConnectionFactory derived = ConnectionFactoryBuilder.derivefrom(connectionFactory).database(derivedDatabaseName) + .build(); + ConnectionFactoryOptions derivedOptions = ((OptionsCapableConnectionFactory) derived).getOptions(); + assertThat(derivedOptions.getRequiredValue(ConnectionFactoryOptions.DATABASE)).isEqualTo(derivedDatabaseName); + assertMatchingOptions(derivedOptions, initialOptions, ConnectionFactoryOptions.CONNECT_TIMEOUT, + ConnectionFactoryOptions.DRIVER, ConnectionFactoryOptions.HOST, ConnectionFactoryOptions.PASSWORD, + ConnectionFactoryOptions.PORT, ConnectionFactoryOptions.PROTOCOL, ConnectionFactoryOptions.SSL, + ConnectionFactoryOptions.USER); + } + + @Test + void buildWhenDerivedWithNewCredentialsReturnsNewConnectionFactory() { + ConnectionFactory connectionFactory = ConnectionFactoryBuilder + .withUrl(EmbeddedDatabaseConnection.H2.getUrl(UUID.randomUUID().toString())).build(); + ConnectionFactoryOptions initialOptions = ((OptionsCapableConnectionFactory) connectionFactory).getOptions(); + ConnectionFactory derived = ConnectionFactoryBuilder.derivefrom(connectionFactory).username("admin") + .password("secret").build(); + ConnectionFactoryOptions derivedOptions = ((OptionsCapableConnectionFactory) derived).getOptions(); + assertThat(derivedOptions.getRequiredValue(ConnectionFactoryOptions.USER)).isEqualTo("admin"); + assertThat(derivedOptions.getRequiredValue(ConnectionFactoryOptions.PASSWORD)).isEqualTo("secret"); + assertMatchingOptions(derivedOptions, initialOptions, ConnectionFactoryOptions.CONNECT_TIMEOUT, + ConnectionFactoryOptions.DATABASE, ConnectionFactoryOptions.DRIVER, ConnectionFactoryOptions.HOST, + ConnectionFactoryOptions.PORT, ConnectionFactoryOptions.PROTOCOL, ConnectionFactoryOptions.SSL); + } + + @Test + void buildWhenDerivedFromPoolReturnsNewNonPooledConnectionFactory() { + ConnectionFactory connectionFactory = ConnectionFactoryBuilder + .withUrl(EmbeddedDatabaseConnection.H2.getUrl(UUID.randomUUID().toString())).build(); + ConnectionFactoryOptions initialOptions = ((OptionsCapableConnectionFactory) connectionFactory).getOptions(); + ConnectionPoolConfiguration poolConfiguration = ConnectionPoolConfiguration.builder(connectionFactory).build(); + ConnectionPool pool = new ConnectionPool(poolConfiguration); + ConnectionFactory derived = ConnectionFactoryBuilder.derivefrom(pool).username("admin").password("secret") + .build(); + assertThat(derived).isNotInstanceOf(ConnectionPool.class).isInstanceOf(OptionsCapableConnectionFactory.class); + ConnectionFactoryOptions derivedOptions = ((OptionsCapableConnectionFactory) derived).getOptions(); + assertThat(derivedOptions.getRequiredValue(ConnectionFactoryOptions.USER)).isEqualTo("admin"); + assertThat(derivedOptions.getRequiredValue(ConnectionFactoryOptions.PASSWORD)).isEqualTo("secret"); + assertMatchingOptions(derivedOptions, initialOptions, ConnectionFactoryOptions.CONNECT_TIMEOUT, + ConnectionFactoryOptions.DATABASE, ConnectionFactoryOptions.DRIVER, ConnectionFactoryOptions.HOST, + ConnectionFactoryOptions.PORT, ConnectionFactoryOptions.PROTOCOL, ConnectionFactoryOptions.SSL); + } + + @ParameterizedTest + @SuppressWarnings({ "rawtypes", "unchecked" }) + @MethodSource("poolingConnectionProviderOptions") + void optionIsMappedWhenCreatingPoolConfiguration(Option option) { + String url = "r2dbc:pool:h2:mem:///" + UUID.randomUUID().toString(); + ExpectedOption expectedOption = ExpectedOption.get(option); + ConnectionFactoryOptions options = ConnectionFactoryBuilder.withUrl(url).configure((builder) -> builder + .option(PoolingConnectionFactoryProvider.POOL_NAME, "defaultName").option(option, expectedOption.value)) + .buildOptions(); + ConnectionPoolConfiguration configuration = new PoolingAwareOptionsCapableWrapper() + .connectionPoolConfiguration(options, mock(ConnectionFactory.class)); + assertThat(configuration).extracting(expectedOption.property).isEqualTo(expectedOption.value); + } + + @ParameterizedTest + @SuppressWarnings({ "rawtypes", "unchecked" }) + @MethodSource("poolingConnectionProviderOptions") + void stringlyTypedOptionIsMappedWhenCreatingPoolConfiguration(Option option) { + String url = "r2dbc:pool:h2:mem:///" + UUID.randomUUID().toString(); + ExpectedOption expectedOption = ExpectedOption.get(option); + ConnectionFactoryOptions options = ConnectionFactoryBuilder.withUrl(url) + .configure((builder) -> builder.option(PoolingConnectionFactoryProvider.POOL_NAME, "defaultName") + .option(option, expectedOption.value.toString())) + .buildOptions(); + ConnectionPoolConfiguration configuration = new PoolingAwareOptionsCapableWrapper() + .connectionPoolConfiguration(options, mock(ConnectionFactory.class)); + assertThat(configuration).extracting(expectedOption.property).isEqualTo(expectedOption.value); + } + + private void assertMatchingOptions(ConnectionFactoryOptions actualOptions, ConnectionFactoryOptions expectedOptions, + Option... optionsToCheck) { + for (Option option : optionsToCheck) { + assertThat(actualOptions.getValue(option)).as(option.name()).isEqualTo(expectedOptions.getValue(option)); + } + } + + private static Iterable poolingConnectionProviderOptions() { + List arguments = new ArrayList<>(); + ReflectionUtils.doWithFields(PoolingConnectionFactoryProvider.class, + (field) -> arguments.add(Arguments.of((Option) ReflectionUtils.getField(field, null))), + (field) -> Option.class.equals(field.getType())); + return arguments; + } + + private enum ExpectedOption { + + ACQUIRE_RETRY(PoolingConnectionFactoryProvider.ACQUIRE_RETRY, 4, "acquireRetry"), + + INITIAL_SIZE(PoolingConnectionFactoryProvider.INITIAL_SIZE, 2, "initialSize"), + + MAX_SIZE(PoolingConnectionFactoryProvider.MAX_SIZE, 8, "maxSize"), + + MAX_LIFE_TIME(PoolingConnectionFactoryProvider.MAX_LIFE_TIME, Duration.ofMinutes(2), "maxLifeTime"), + + MAX_ACQUIRE_TIME(PoolingConnectionFactoryProvider.MAX_ACQUIRE_TIME, Duration.ofSeconds(30), "maxAcquireTime"), + + MAX_IDLE_TIME(PoolingConnectionFactoryProvider.MAX_IDLE_TIME, Duration.ofMinutes(1), "maxIdleTime"), + + MAX_CREATE_CONNECTION_TIME(PoolingConnectionFactoryProvider.MAX_CREATE_CONNECTION_TIME, Duration.ofSeconds(10), + "maxCreateConnectionTime"), + + POOL_NAME(PoolingConnectionFactoryProvider.POOL_NAME, "testPool", "name"), + + REGISTER_JMX(PoolingConnectionFactoryProvider.REGISTER_JMX, true, "registerJmx"), + + VALIDATION_QUERY(PoolingConnectionFactoryProvider.VALIDATION_QUERY, "SELECT 1", "validationQuery"), + + VALIDATION_DEPTH(PoolingConnectionFactoryProvider.VALIDATION_DEPTH, ValidationDepth.REMOTE, "validationDepth"); + + private final Option option; + + private final Object value; + + private final String property; + + ExpectedOption(Option option, Object value, String property) { + this.option = option; + this.value = value; + this.property = property; + } + + static ExpectedOption get(Option option) { + for (ExpectedOption expectedOption : values()) { + if (expectedOption.option == option) { + return expectedOption; + } + } + throw new IllegalArgumentException("Unexpected option: '" + option + "'"); + } + + } + }