diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java index c4aa64de823..83e4c692a85 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java @@ -55,6 +55,12 @@ import org.springframework.util.Assert; */ public interface DatabaseClient extends ConnectionAccessor { + /** + * Return the {@link ConnectionFactory} that this client uses. + * @return the connection factory + */ + ConnectionFactory getConnectionFactory(); + /** * Specify a static {@code sql} statement to run. Contract for specifying a * SQL call along with options leading to the execution. The SQL string can diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java index 0f8c3197727..8663ef73093 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java @@ -85,6 +85,11 @@ class DefaultDatabaseClient implements DatabaseClient { } + @Override + public ConnectionFactory getConnectionFactory() { + return this.connectionFactory; + } + @Override public GenericExecuteSpec sql(String sql) { Assert.hasText(sql, "SQL must not be null or empty"); diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java index 78e14e4211c..a1a52061664 100644 --- a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java @@ -46,6 +46,7 @@ import org.springframework.lang.Nullable; import org.springframework.r2dbc.core.binding.BindMarkersFactory; import org.springframework.r2dbc.core.binding.BindTarget; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.doReturn; @@ -85,6 +86,15 @@ class DefaultDatabaseClientUnitTests { connectionFactory).bindMarkers(BindMarkersFactory.indexed("$", 1)); } + @Test + void connectionFactoryIsExposed() { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + DatabaseClient databaseClient = DatabaseClient.builder() + .connectionFactory(connectionFactory) + .bindMarkers(BindMarkersFactory.anonymous("?")).build(); + assertThat(databaseClient.getConnectionFactory()).isSameAs(connectionFactory); + } + @Test void shouldCloseConnectionOnlyOnce() { DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) databaseClientBuilder.build();