Enforce BeanRegistrationExcludeFilter beans are also AOT processors

Update `BeanDefinitionMethodGeneratorFactory` to enforce that any
`BeanRegistrationExcludeFilter` filter that is from a bean factory
also implements an AOT processor interface.

See gh-28866
This commit is contained in:
Phillip Webb 2022-07-28 13:48:52 +01:00
parent a4f3d1d6e8
commit 9413383d72
3 changed files with 42 additions and 5 deletions

View File

@ -22,10 +22,12 @@ import java.util.List;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.aot.AotServices.Source;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
/** /**
@ -65,6 +67,14 @@ class BeanDefinitionMethodGeneratorFactory {
BeanDefinitionMethodGeneratorFactory(AotServices.Loader loader) { BeanDefinitionMethodGeneratorFactory(AotServices.Loader loader) {
this.aotProcessors = loader.load(BeanRegistrationAotProcessor.class); this.aotProcessors = loader.load(BeanRegistrationAotProcessor.class);
this.excludeFilters = loader.load(BeanRegistrationExcludeFilter.class); this.excludeFilters = loader.load(BeanRegistrationExcludeFilter.class);
for (BeanRegistrationExcludeFilter excludeFilter : this.excludeFilters) {
if (this.excludeFilters.getSource(excludeFilter) == Source.BEAN_FACTORY) {
Assert.state(excludeFilter instanceof BeanRegistrationAotProcessor
|| excludeFilter instanceof BeanFactoryInitializationAotProcessor,
() -> "BeanRegistrationExcludeFilter bean of type %s must also implement an AOT processor interface"
.formatted(excludeFilter.getClass().getName()));
}
}
} }
@ -97,7 +107,7 @@ class BeanDefinitionMethodGeneratorFactory {
return true; return true;
} }
for (BeanRegistrationExcludeFilter excludeFilter : this.excludeFilters) { for (BeanRegistrationExcludeFilter excludeFilter : this.excludeFilters) {
if (excludeFilter.isExcluded(registeredBean)) { if (excludeFilter.isExcludedFromAotProcessing(registeredBean)) {
logger.trace(LogMessage.format( logger.trace(LogMessage.format(
"Excluding registered bean '%s' from bean factory %s due to %s", "Excluding registered bean '%s' from bean factory %s due to %s",
registeredBean.getBeanName(), registeredBean.getBeanName(),

View File

@ -35,6 +35,6 @@ public interface BeanRegistrationExcludeFilter {
* @param registeredBean the registered bean * @param registeredBean the registered bean
* @return if the registered bean should be excluded * @return if the registered bean should be excluded
*/ */
boolean isExcluded(RegisteredBean registeredBean); boolean isExcludedFromAotProcessing(RegisteredBean registeredBean);
} }

View File

@ -26,6 +26,8 @@ import org.springframework.core.Ordered;
import org.springframework.core.mock.MockSpringFactoriesLoader; import org.springframework.core.mock.MockSpringFactoriesLoader;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
/** /**
@ -35,6 +37,25 @@ import static org.mockito.Mockito.mock;
*/ */
class BeanDefinitionMethodGeneratorFactoryTests { class BeanDefinitionMethodGeneratorFactoryTests {
@Test
void createWhenBeanRegistrationExcludeFilterBeanIsNotAotProcessorThrowsException() {
BeanRegistrationExcludeFilter filter = registeredBean -> false;
DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
beanFactory.registerSingleton("filter", filter);
assertThatIllegalStateException()
.isThrownBy(() -> new BeanDefinitionMethodGeneratorFactory(beanFactory))
.withMessageContaining("also implement an AOT processor interface");
}
@Test
void createWhenBeanRegistrationExcludeFilterFactoryIsNotAotProcessorLoads() {
BeanRegistrationExcludeFilter filter = registeredBean -> false;
MockSpringFactoriesLoader loader = new MockSpringFactoriesLoader();
loader.addInstance(BeanRegistrationExcludeFilter.class, filter);
assertThatNoException().isThrownBy(() -> new BeanDefinitionMethodGeneratorFactory(
AotServices.factories(loader)));
}
@Test @Test
void getBeanDefinitionMethodGeneratorWhenExcludedByBeanRegistrationExcludeFilterReturnsNull() { void getBeanDefinitionMethodGeneratorWhenExcludedByBeanRegistrationExcludeFilterReturnsNull() {
MockSpringFactoriesLoader springFactoriesLoader = new MockSpringFactoriesLoader(); MockSpringFactoriesLoader springFactoriesLoader = new MockSpringFactoriesLoader();
@ -144,8 +165,8 @@ class BeanDefinitionMethodGeneratorFactoryTests {
return RegisteredBean.of(beanFactory, "test"); return RegisteredBean.of(beanFactory, "test");
} }
static class MockBeanRegistrationExcludeFilter static class MockBeanRegistrationExcludeFilter implements
implements BeanRegistrationExcludeFilter, Ordered { BeanRegistrationAotProcessor, BeanRegistrationExcludeFilter, Ordered {
private final boolean excluded; private final boolean excluded;
@ -159,7 +180,13 @@ class BeanDefinitionMethodGeneratorFactoryTests {
} }
@Override @Override
public boolean isExcluded(RegisteredBean registeredBean) { public BeanRegistrationAotContribution processAheadOfTime(
RegisteredBean registeredBean) {
return null;
}
@Override
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
this.registeredBean = registeredBean; this.registeredBean = registeredBean;
return this.excluded; return this.excluded;
} }