Access connection factory from DatabaseClient

This commit provides an accessor for the underlying ConnectionFactory
that a DatabaseClient uses.

Closes gh-25521
This commit is contained in:
Stephane Nicoll 2020-08-04 11:45:32 +02:00
parent 673f83e388
commit 1dcd7f418f
3 changed files with 21 additions and 0 deletions

View File

@ -55,6 +55,12 @@ import org.springframework.util.Assert;
*/
public interface DatabaseClient extends ConnectionAccessor {
/**
* Return the {@link ConnectionFactory} that this client uses.
* @return the connection factory
*/
ConnectionFactory getConnectionFactory();
/**
* Specify a static {@code sql} statement to run. Contract for specifying a
* SQL call along with options leading to the execution. The SQL string can

View File

@ -85,6 +85,11 @@ class DefaultDatabaseClient implements DatabaseClient {
}
@Override
public ConnectionFactory getConnectionFactory() {
return this.connectionFactory;
}
@Override
public GenericExecuteSpec sql(String sql) {
Assert.hasText(sql, "SQL must not be null or empty");

View File

@ -46,6 +46,7 @@ import org.springframework.lang.Nullable;
import org.springframework.r2dbc.core.binding.BindMarkersFactory;
import org.springframework.r2dbc.core.binding.BindTarget;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.doReturn;
@ -85,6 +86,15 @@ class DefaultDatabaseClientUnitTests {
connectionFactory).bindMarkers(BindMarkersFactory.indexed("$", 1));
}
@Test
void connectionFactoryIsExposed() {
ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
DatabaseClient databaseClient = DatabaseClient.builder()
.connectionFactory(connectionFactory)
.bindMarkers(BindMarkersFactory.anonymous("?")).build();
assertThat(databaseClient.getConnectionFactory()).isSameAs(connectionFactory);
}
@Test
void shouldCloseConnectionOnlyOnce() {
DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) databaseClientBuilder.build();