diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index 26df73d45ee..d5281ef6caf 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -117,7 +117,8 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme Class beanClass = (mergedBeanDefinition.hasBeanClass() ? ClassUtils.getUserClass(mergedBeanDefinition.getBeanClass()) : null); CodeBlock beanClassCode = generateBeanClassCode( - beanRegistrationCode.getClassName().packageName(), beanClass); + beanRegistrationCode.getClassName().packageName(), + (beanClass != null ? beanClass : beanType.toClass())); code.addStatement("$T $L = new $T($L)", RootBeanDefinition.class, BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, beanClassCode); if (targetTypeNecessary(beanType, beanClass)) { @@ -127,16 +128,13 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme return code.build(); } - private CodeBlock generateBeanClassCode(String targetPackage, @Nullable Class beanClass) { - if (beanClass != null) { - if (Modifier.isPublic(beanClass.getModifiers()) || targetPackage.equals(beanClass.getPackageName())) { - return CodeBlock.of("$T.class", beanClass); - } - else { - return CodeBlock.of("$S", beanClass.getName()); - } + private CodeBlock generateBeanClassCode(String targetPackage, Class beanClass) { + if (Modifier.isPublic(beanClass.getModifiers()) || targetPackage.equals(beanClass.getPackageName())) { + return CodeBlock.of("$T.class", beanClass); + } + else { + return CodeBlock.of("$S", beanClass.getName()); } - return CodeBlock.of(""); } private CodeBlock generateBeanTypeCode(ResolvableType beanType) { @@ -147,11 +145,14 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme } private boolean targetTypeNecessary(ResolvableType beanType, @Nullable Class beanClass) { - if (beanType.hasGenerics() || beanClass == null) { + if (beanType.hasGenerics()) { return true; } - return (!beanType.toClass().equals(beanClass) - || this.registeredBean.getMergedBeanDefinition().getFactoryMethodName() != null); + if (beanClass != null + && this.registeredBean.getMergedBeanDefinition().getFactoryMethodName() != null) { + return true; + } + return (beanClass != null && !beanType.toClass().equals(beanClass)); } @Override diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 3d490ec1a23..aae78a040a8 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -47,6 +47,7 @@ import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.factory.aot.InnerBeanConfiguration; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean; +import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanConfiguration; import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy; import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Implementation; import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.One; @@ -92,29 +93,8 @@ class BeanDefinitionMethodGeneratorTests { this.beanRegistrationsCode = new MockBeanRegistrationsCode(this.generationContext); } - @Test - void generateBeanDefinitionMethodWithOnlyTargetTypeDoesNotSetBeanClass() { - RootBeanDefinition beanDefinition = new RootBeanDefinition(); - beanDefinition.setTargetType(TestBean.class); - RegisteredBean registeredBean = registerBean(beanDefinition); - BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( - this.methodGeneratorFactory, registeredBean, null, - Collections.emptyList()); - MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, this.beanRegistrationsCode); - compile(method, (actual, compiled) -> { - SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); - assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); - assertThat(sourceFile).contains("new RootBeanDefinition()"); - assertThat(sourceFile).contains("setTargetType(TestBean.class)"); - assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); - assertThat(actual).isInstanceOf(RootBeanDefinition.class); - }); - } - - @Test - void generateBeanDefinitionMethodSpecifiesBeanClassIfSet() { + void generateWithBeanClassSetsOnlyBeanClass() { RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class); RegisteredBean registeredBean = registerBean(beanDefinition); BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( @@ -133,7 +113,91 @@ class BeanDefinitionMethodGeneratorTests { } @Test - void generateBeanDefinitionMethodSpecifiesBeanClassAndTargetTypIfDifferent() { + void generateWithTargetTypeWithNoGenericSetsOnlyBeanClass() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(); + beanDefinition.setTargetType(TestBean.class); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(TestBean.class)"); + assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateWithTargetTypeUsingGenericsSetsBothBeanClassAndTargetType() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(); + beanDefinition.setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class)); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + assertThat(actual.getResolvableType().resolve()).isEqualTo(GenericBean.class); + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(GenericBean.class)"); + assertThat(sourceFile).contains( + "setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class))"); + assertThat(sourceFile).contains("setInstanceSupplier(GenericBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateWithBeanClassAndFactoryMethodNameSetsTargetTypeAndBeanClass() { + this.beanFactory.registerSingleton("factory", new SimpleBeanConfiguration()); + RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleBean.class); + beanDefinition.setFactoryBeanName("factory"); + beanDefinition.setFactoryMethodName("simpleBean"); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(SimpleBean.class)"); + assertThat(sourceFile).contains("setTargetType(SimpleBean.class)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateWithTargetTypeAndFactoryMethodNameSetsOnlyBeanClass() { + this.beanFactory.registerSingleton("factory", new SimpleBeanConfiguration()); + RootBeanDefinition beanDefinition = new RootBeanDefinition(); + beanDefinition.setTargetType(SimpleBean.class); + beanDefinition.setFactoryBeanName("factory"); + beanDefinition.setFactoryMethodName("simpleBean"); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(SimpleBean.class)"); + assertThat(sourceFile).doesNotContain("setTargetType("); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateWithBeanClassAndTargetTypeDifferentSetsBoth() { RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class); beanDefinition.setTargetType(Implementation.class); beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean")); @@ -152,6 +216,28 @@ class BeanDefinitionMethodGeneratorTests { }); } + @Test + void generateWithBeanClassAndTargetTypWithGenericSetsBoth() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(Integer.class); + beanDefinition.setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class)); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + assertThat(actual.getResolvableType().resolve()).isEqualTo(GenericBean.class); + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(Integer.class)"); + assertThat(sourceFile).contains( + "setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class))"); + assertThat(sourceFile).contains("setInstanceSupplier(GenericBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + @Test void generateBeanDefinitionMethodUSeBeanClassNameIfNotReachable() { RootBeanDefinition beanDefinition = new RootBeanDefinition(PackagePrivateTestBean.class);