Allow BeanRegistrationContributionProvider to access the BeanFactory

Closes gh-28384
This commit is contained in:
Stephane Nicoll 2022-04-26 15:02:54 +02:00
parent 88eac7794c
commit 7ea0cc3da2
3 changed files with 59 additions and 1 deletions

View File

@ -28,6 +28,7 @@ import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.io.support.SpringFactoriesLoader;
import org.springframework.core.io.support.SpringFactoriesLoader.ArgumentResolver;
/**
* A {@link BeanFactoryContribution} that generates the bean definitions of a
@ -62,7 +63,7 @@ public class BeanDefinitionsContribution implements BeanFactoryContribution {
private static List<BeanRegistrationContributionProvider> initializeProviders(DefaultListableBeanFactory beanFactory) {
List<BeanRegistrationContributionProvider> providers = new ArrayList<>(SpringFactoriesLoader.loadFactories(
BeanRegistrationContributionProvider.class, beanFactory.getBeanClassLoader()));
BeanRegistrationContributionProvider.class, beanFactory.getBeanClassLoader(), ArgumentResolver.from(type -> type.isInstance(beanFactory) ? beanFactory : null)));
providers.add(new DefaultBeanRegistrationContributionProvider(beanFactory));
return providers;
}

View File

@ -16,6 +16,9 @@
package org.springframework.beans.factory.generator;
import java.io.IOException;
import java.net.URL;
import java.util.Enumeration;
import java.util.List;
import java.util.function.BiPredicate;
@ -28,11 +31,14 @@ import org.springframework.aot.generator.DefaultGeneratedTypeContext;
import org.springframework.aot.generator.GeneratedType;
import org.springframework.aot.generator.GeneratedTypeContext;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.support.CodeSnippet;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -44,6 +50,19 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
*/
class BeanDefinitionsContributionTests {
@Test
void loadContributorWithConstructorArgumentOnBeanFactory() {
DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
beanFactory.setBeanClassLoader(new TestSpringFactoriesClassLoader(
"bean-registration-contribution-provider-constructor.factories"));
BeanDefinitionsContribution contribution = new BeanDefinitionsContribution(beanFactory);
assertThat(contribution).extracting("contributionProviders").asList()
.anySatisfy(provider -> assertThat(provider).isInstanceOfSatisfying(TestConstructorBeanRegistrationContributionProvider.class,
testProvider -> assertThat(testProvider.beanFactory).isSameAs(beanFactory)))
.anySatisfy(provider -> assertThat(provider).isInstanceOf(DefaultBeanRegistrationContributionProvider.class))
.hasSize(2);
}
@Test
void contributeThrowsContributionNotFoundIfNoContributionIsAvailable() {
DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
@ -158,4 +177,40 @@ class BeanDefinitionsContributionTests {
}
}
static class TestConstructorBeanRegistrationContributionProvider implements BeanRegistrationContributionProvider {
private final ConfigurableListableBeanFactory beanFactory;
TestConstructorBeanRegistrationContributionProvider(ConfigurableListableBeanFactory beanFactory) {
Assert.notNull(beanFactory, "BeanFactory must not be null");
this.beanFactory = beanFactory;
}
@Nullable
@Override
public BeanFactoryContribution getContributionFor(String beanName, RootBeanDefinition beanDefinition) {
return null;
}
}
static class TestSpringFactoriesClassLoader extends ClassLoader {
private final String factoriesName;
TestSpringFactoriesClassLoader(String factoriesName) {
super(BeanDefinitionsContributionTests.class.getClassLoader());
this.factoriesName = factoriesName;
}
@Override
public Enumeration<URL> getResources(String name) throws IOException {
if ("META-INF/spring.factories".equals(name)) {
return super.getResources("org/springframework/beans/factory/generator/" + this.factoriesName);
}
return super.getResources(name);
}
}
}

View File

@ -0,0 +1,2 @@
org.springframework.beans.factory.generator.BeanRegistrationContributionProvider= \
org.springframework.beans.factory.generator.BeanDefinitionsContributionTests.TestConstructorBeanRegistrationContributionProvider