Polish "Use DataSource.unwrap to get routing data source"

See gh-42313
This commit is contained in:
Stéphane Nicoll 2024-09-16 09:29:04 +02:00
parent 3f9f0490a6
commit 78a140ae25
2 changed files with 20 additions and 81 deletions

View File

@ -89,7 +89,7 @@ public class DataSourceHealthContributorAutoConfiguration implements Initializin
if (dataSourceHealthIndicatorProperties.isIgnoreRoutingDataSources()) {
Map<String, DataSource> filteredDatasources = dataSources.entrySet()
.stream()
.filter((e) -> !isAbstractRoutingDataSource(e.getValue()))
.filter((e) -> !isRoutingDataSource(e.getValue()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
return createContributor(filteredDatasources);
}
@ -105,9 +105,8 @@ public class DataSourceHealthContributorAutoConfiguration implements Initializin
}
private HealthContributor createContributor(DataSource source) {
if (isAbstractRoutingDataSource(source)) {
return new RoutingDataSourceHealthContributor(unwrapAbstractRoutingDataSource(source),
this::createContributor);
if (isRoutingDataSource(source)) {
return new RoutingDataSourceHealthContributor(extractRoutingDataSource(source), this::createContributor);
}
return new DataSourceHealthIndicator(source, getValidationQuery(source));
}
@ -117,7 +116,7 @@ public class DataSourceHealthContributorAutoConfiguration implements Initializin
return (poolMetadata != null) ? poolMetadata.getValidationQuery() : null;
}
private static boolean isAbstractRoutingDataSource(DataSource dataSource) {
private static boolean isRoutingDataSource(DataSource dataSource) {
if (dataSource instanceof AbstractRoutingDataSource) {
return true;
}
@ -129,7 +128,7 @@ public class DataSourceHealthContributorAutoConfiguration implements Initializin
}
}
private static AbstractRoutingDataSource unwrapAbstractRoutingDataSource(DataSource dataSource) {
private static AbstractRoutingDataSource extractRoutingDataSource(DataSource dataSource) {
if (dataSource instanceof AbstractRoutingDataSource routingDataSource) {
return routingDataSource;
}
@ -137,8 +136,7 @@ public class DataSourceHealthContributorAutoConfiguration implements Initializin
return dataSource.unwrap(AbstractRoutingDataSource.class);
}
catch (SQLException ex) {
throw new IllegalStateException(
"DataSource '%s' failed to unwrap '%s'".formatted(dataSource, AbstractRoutingDataSource.class), ex);
throw new IllegalStateException("Failed to unwrap AbstractRoutingDataSource from " + dataSource, ex);
}
}

View File

@ -16,15 +16,9 @@
package org.springframework.boot.actuate.autoconfigure.jdbc;
import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.ConnectionBuilder;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.ShardingKeyBuilder;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
import javax.sql.DataSource;
@ -256,11 +250,24 @@ class DataSourceHealthContributorAutoConfigurationTests {
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof DataSource dataSource) {
return new ProxyDataSource(dataSource);
return proxyDataSource(dataSource);
}
return bean;
}
private static DataSource proxyDataSource(DataSource dataSource) {
try {
DataSource mock = mock(DataSource.class);
given(mock.isWrapperFor(AbstractRoutingDataSource.class))
.willReturn(dataSource instanceof AbstractRoutingDataSource);
given(mock.unwrap(AbstractRoutingDataSource.class)).willAnswer((invocation) -> dataSource);
return mock;
}
catch (SQLException ex) {
throw new IllegalStateException(ex);
}
}
}
@Configuration(proxyBeanMethods = false)
@ -280,70 +287,4 @@ class DataSourceHealthContributorAutoConfigurationTests {
}
static class ProxyDataSource implements DataSource {
private final DataSource dataSource;
ProxyDataSource(DataSource dataSource) {
this.dataSource = dataSource;
}
@Override
public void setLogWriter(PrintWriter out) throws SQLException {
this.dataSource.setLogWriter(out);
}
@Override
public Connection getConnection() throws SQLException {
return this.dataSource.getConnection();
}
@Override
public Connection getConnection(String username, String password) throws SQLException {
return this.dataSource.getConnection(username, password);
}
@Override
public PrintWriter getLogWriter() throws SQLException {
return this.dataSource.getLogWriter();
}
@Override
public void setLoginTimeout(int seconds) throws SQLException {
this.dataSource.setLoginTimeout(seconds);
}
@Override
public int getLoginTimeout() throws SQLException {
return this.dataSource.getLoginTimeout();
}
@Override
public ConnectionBuilder createConnectionBuilder() throws SQLException {
return this.dataSource.createConnectionBuilder();
}
@Override
public Logger getParentLogger() throws SQLFeatureNotSupportedException {
return this.dataSource.getParentLogger();
}
@Override
public ShardingKeyBuilder createShardingKeyBuilder() throws SQLException {
return this.dataSource.createShardingKeyBuilder();
}
@Override
@SuppressWarnings("unchecked")
public <T> T unwrap(Class<T> iface) throws SQLException {
return iface.isInstance(this) ? (T) this : this.dataSource.unwrap(iface);
}
@Override
public boolean isWrapperFor(Class<?> iface) throws SQLException {
return (iface.isInstance(this) || this.dataSource.isWrapperFor(iface));
}
}
}