diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java index c2b9e7a4f50..4741ff185fd 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java @@ -513,7 +513,12 @@ class BeanDefinitionPropertyValueCodeGenerator { @Override @Nullable public CodeBlock generateCode(Object value, ResolvableType type) { - if (value instanceof BeanReference beanReference) { + if (value instanceof RuntimeBeanReference runtimeBeanReference + && runtimeBeanReference.getBeanType() != null) { + return CodeBlock.of("new $T($T.class)", RuntimeBeanReference.class, + runtimeBeanReference.getBeanType()); + } + else if (value instanceof BeanReference beanReference) { return CodeBlock.of("new $T($S)", RuntimeBeanReference.class, beanReference.getBeanName()); } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java index cacecfbc176..4eff8169129 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java @@ -40,6 +40,7 @@ import org.springframework.aot.test.generator.compile.Compiled; import org.springframework.aot.test.generator.compile.TestCompiler; import org.springframework.beans.factory.config.BeanReference; import org.springframework.beans.factory.config.RuntimeBeanNameReference; +import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.ManagedList; import org.springframework.beans.factory.support.ManagedMap; import org.springframework.beans.factory.support.ManagedSet; @@ -465,10 +466,31 @@ class BeanDefinitionPropertyValueCodeGeneratorTests { class BeanReferenceTests { @Test - void generatedWhenBeanReference() { - BeanReference beanReference = new RuntimeBeanNameReference("test"); - compile(beanReference, (instance, compiler) -> - assertThat(((BeanReference) instance).getBeanName()).isEqualTo(beanReference.getBeanName())); + void generatedWhenBeanNameReference() { + RuntimeBeanNameReference beanReference = new RuntimeBeanNameReference("test"); + compile(beanReference, (instance, compiler) -> { + RuntimeBeanReference actual = (RuntimeBeanReference) instance; + assertThat(actual.getBeanName()).isEqualTo(beanReference.getBeanName()); + }); + } + + @Test + void generatedWhenBeanReferenceByName() { + RuntimeBeanReference beanReference = new RuntimeBeanReference("test"); + compile(beanReference, (instance, compiler) -> { + RuntimeBeanReference actual = (RuntimeBeanReference) instance; + assertThat(actual.getBeanName()).isEqualTo(beanReference.getBeanName()); + assertThat(actual.getBeanType()).isEqualTo(beanReference.getBeanType()); + }); + } + + @Test + void generatedWhenBeanReferenceByType() { + BeanReference beanReference = new RuntimeBeanReference(String.class); + compile(beanReference, (instance, compiler) -> { + RuntimeBeanReference actual = (RuntimeBeanReference) instance; + assertThat(actual.getBeanType()).isEqualTo(String.class); + }); } }