diff --git a/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java b/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java index 849c6a9e9ce..23e4be2f4fc 100644 --- a/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java +++ b/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java @@ -364,7 +364,7 @@ final class PostProcessorRegistrationDelegate { * Register the given BeanPostProcessor beans. */ private static void registerBeanPostProcessors( - ConfigurableListableBeanFactory beanFactory, List postProcessors) { + ConfigurableListableBeanFactory beanFactory, List postProcessors) { if (beanFactory instanceof AbstractBeanFactory) { // Bulk addition is more efficient against our CopyOnWriteArrayList there @@ -439,6 +439,7 @@ final class PostProcessorRegistrationDelegate { Class beanType = resolveBeanType(bd); postProcessRootBeanDefinition(postProcessors, beanName, beanType, bd); } + registerBeanPostProcessors(this.beanFactory, postProcessors); } private void postProcessRootBeanDefinition(List postProcessors, diff --git a/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java b/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java index adc78801766..57831c404d7 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java @@ -431,6 +431,24 @@ class AnnotationConfigApplicationContextTests { "annotationConfigApplicationContextTests.Config", "testBean"); } + @Test + void refreshForAotCanInstantiateBeanWithAutowiredApplicationContext() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(BeanD.class); + context.refreshForAotProcessing(); + BeanD bean = context.getBean(BeanD.class); + assertThat(bean.applicationContext).isSameAs(context); + } + + @Test + void refreshForAotCanInstantiateBeanWithFieldAutowiredApplicationContext() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(BeanB.class); + context.refreshForAotProcessing(); + BeanB bean = context.getBean(BeanB.class); + assertThat(bean.applicationContext).isSameAs(context); + } + @Configuration static class Config { @@ -506,6 +524,16 @@ class AnnotationConfigApplicationContextTests { static class BeanC {} + static class BeanD { + + private final ApplicationContext applicationContext; + + public BeanD(ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + } + + } + static class NonInstantiatedFactoryBean implements FactoryBean { NonInstantiatedFactoryBean() { diff --git a/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java b/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java index fd8ec331824..96cad6c1596 100644 --- a/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java +++ b/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java @@ -23,10 +23,12 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.OS; import org.mockito.ArgumentCaptor; +import org.springframework.beans.BeansException; import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.config.AbstractFactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.GenericBeanDefinition; import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor; @@ -407,6 +409,32 @@ class GenericApplicationContextTests { context.close(); } + @Test + void refreshForAotInvokesBeanPostProcessorContractOnMergedBeanDefinitionPostProcessors() { + MergedBeanDefinitionPostProcessor bpp = new MergedBeanDefinitionPostProcessor() { + @Override + public void postProcessMergedBeanDefinition(RootBeanDefinition beanDefinition, Class beanType, String beanName) { + beanDefinition.setAttribute("mbdppCalled", true); + } + + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { + return (beanName.equals("test") ? "42" : bean); + } + }; + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("bpp", BeanDefinitionBuilder.rootBeanDefinition( + MergedBeanDefinitionPostProcessor.class, () -> bpp) + .setRole(BeanDefinition.ROLE_INFRASTRUCTURE).getBeanDefinition()); + AbstractBeanDefinition bd = BeanDefinitionBuilder.rootBeanDefinition(String.class) + .addConstructorArgValue("value").getBeanDefinition(); + context.registerBeanDefinition("test", bd); + context.refreshForAotProcessing(); + assertThat(context.getBeanFactory().getMergedBeanDefinition("test") + .hasAttribute("mbdppCalled")).isTrue(); + assertThat(context.getBean("test")).isEqualTo("42"); + } + @Test void refreshForAotFailsOnAnActiveContext() { GenericApplicationContext context = new GenericApplicationContext();