diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index 38c0a23072a..26df73d45ee 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.beans.factory.aot; import java.lang.reflect.Constructor; import java.lang.reflect.Executable; +import java.lang.reflect.Modifier; import java.util.List; import java.util.function.Predicate; @@ -47,12 +48,6 @@ import org.springframework.util.ClassUtils; */ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragments { - /** - * The variable name used to hold the bean type. - */ - private static final String BEAN_TYPE_VARIABLE = "beanType"; - - private final BeanRegistrationsCode beanRegistrationsCode; private final RegisteredBean registeredBean; @@ -118,19 +113,45 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) { CodeBlock.Builder code = CodeBlock.builder(); - code.addStatement(generateBeanTypeCode(beanType)); + RootBeanDefinition mergedBeanDefinition = this.registeredBean.getMergedBeanDefinition(); + Class beanClass = (mergedBeanDefinition.hasBeanClass() + ? ClassUtils.getUserClass(mergedBeanDefinition.getBeanClass()) : null); + CodeBlock beanClassCode = generateBeanClassCode( + beanRegistrationCode.getClassName().packageName(), beanClass); code.addStatement("$T $L = new $T($L)", RootBeanDefinition.class, - BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, BEAN_TYPE_VARIABLE); + BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, beanClassCode); + if (targetTypeNecessary(beanType, beanClass)) { + code.addStatement("$L.setTargetType($L)", BEAN_DEFINITION_VARIABLE, + generateBeanTypeCode(beanType)); + } return code.build(); } + private CodeBlock generateBeanClassCode(String targetPackage, @Nullable Class beanClass) { + if (beanClass != null) { + if (Modifier.isPublic(beanClass.getModifiers()) || targetPackage.equals(beanClass.getPackageName())) { + return CodeBlock.of("$T.class", beanClass); + } + else { + return CodeBlock.of("$S", beanClass.getName()); + } + } + return CodeBlock.of(""); + } + private CodeBlock generateBeanTypeCode(ResolvableType beanType) { if (!beanType.hasGenerics()) { - return CodeBlock.of("$T $L = $T.class", Class.class, BEAN_TYPE_VARIABLE, - ClassUtils.getUserClass(beanType.toClass())); + return CodeBlock.of("$T.class", ClassUtils.getUserClass(beanType.toClass())); } - return CodeBlock.of("$T $L = $L", ResolvableType.class, BEAN_TYPE_VARIABLE, - ResolvableTypeCodeGenerator.generateCode(beanType)); + return ResolvableTypeCodeGenerator.generateCode(beanType); + } + + private boolean targetTypeNecessary(ResolvableType beanType, @Nullable Class beanClass) { + if (beanType.hasGenerics() || beanClass == null) { + return true; + } + return (!beanType.toClass().equals(beanClass) + || this.registeredBean.getMergedBeanDefinition().getFactoryMethodName() != null); } @Override diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index f108a3b7b3b..3d490ec1a23 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -47,6 +47,10 @@ import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.factory.aot.InnerBeanConfiguration; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Implementation; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.One; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Two; import org.springframework.core.ResolvableType; import org.springframework.core.test.io.support.MockSpringFactoriesLoader; import org.springframework.core.test.tools.CompileWithForkedClassLoader; @@ -56,6 +60,7 @@ import org.springframework.core.test.tools.TestCompiler; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -89,8 +94,10 @@ class BeanDefinitionMethodGeneratorTests { @Test - void generateBeanDefinitionMethodGeneratesMethod() { - RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class)); + void generateBeanDefinitionMethodWithOnlyTargetTypeDoesNotSetBeanClass() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(); + beanDefinition.setTargetType(TestBean.class); + RegisteredBean registeredBean = registerBean(beanDefinition); BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); @@ -99,7 +106,67 @@ class BeanDefinitionMethodGeneratorTests { compile(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); - assertThat(sourceFile).contains("beanType = TestBean.class"); + assertThat(sourceFile).contains("new RootBeanDefinition()"); + assertThat(sourceFile).contains("setTargetType(TestBean.class)"); + assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateBeanDefinitionMethodSpecifiesBeanClassIfSet() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(TestBean.class)"); + assertThat(sourceFile).doesNotContain("setTargetType("); + assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateBeanDefinitionMethodSpecifiesBeanClassAndTargetTypIfDifferent() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class); + beanDefinition.setTargetType(Implementation.class); + beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean")); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(TestHierarchy.One.class)"); + assertThat(sourceFile).contains("setTargetType(TestHierarchy.Implementation.class)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateBeanDefinitionMethodUSeBeanClassNameIfNotReachable() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(PackagePrivateTestBean.class); + beanDefinition.setTargetType(TestBean.class); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(\"org.springframework.beans.factory.aot.PackagePrivateTestBean\""); + assertThat(sourceFile).contains("setTargetType(TestBean.class)"); assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); assertThat(actual).isInstanceOf(RootBeanDefinition.class); }); @@ -116,7 +183,6 @@ class BeanDefinitionMethodGeneratorTests { compile(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); - assertThat(sourceFile).contains("beanType = TestBean.class"); assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); assertThat(actual).isInstanceOf(RootBeanDefinition.class); }); @@ -183,12 +249,26 @@ class BeanDefinitionMethodGeneratorTests { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); assertThat(sourceFile).contains( - "beanType = ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class)"); + "setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class))"); assertThat(sourceFile).contains("setInstanceSupplier(GenericBean::new)"); assertThat(actual).isInstanceOf(RootBeanDefinition.class); }); } + @Test + void generateBeanDefinitionMethodWhenHasExplicitResolvableType() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class); + beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean")); + beanDefinition.setTargetType(Two.class); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> assertThat(actual.getResolvableType().resolve()).isEqualTo(Two.class)); + } + @Test void generateBeanDefinitionMethodWhenHasInstancePostProcessorGeneratesMethod() { RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class)); diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/TestHierarchy.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/TestHierarchy.java new file mode 100644 index 00000000000..df47ee1637e --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/TestHierarchy.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans.factory.aot; + +/** + * A hierarchy where the exposed type of a bean is a partial signature. + * + * @author Stephane Nicoll + */ +public class TestHierarchy { + + public interface One { + } + + public interface Two { + } + + public static class Implementation implements One, Two { + } + + public static One oneBean() { + return new Implementation(); + } + +} diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java index a2ab23b8ade..eb155764afc 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java @@ -49,8 +49,13 @@ import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Implementation; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.One; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Two; import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.AnnotationConfigUtils; @@ -79,6 +84,7 @@ import org.springframework.core.test.tools.CompileWithForkedClassLoader; import org.springframework.core.test.tools.Compiled; import org.springframework.core.test.tools.TestCompiler; import org.springframework.mock.env.MockEnvironment; +import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -331,6 +337,22 @@ class ApplicationContextAotGeneratorTests { }); } + @Test // gh-30689 + void processAheadOfTimeWithExplicitResolvableType() { + GenericApplicationContext applicationContext = new GenericApplicationContext(); + DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class); + beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean")); + // Override target type + beanDefinition.setTargetType(Two.class); + beanFactory.registerBeanDefinition("hierarchyBean", beanDefinition); + testCompiledResult(applicationContext, (initializer, compiled) -> { + GenericApplicationContext freshApplicationContext = toFreshApplicationContext(initializer); + assertThat(freshApplicationContext.getBean(Two.class)) + .isInstanceOf(Implementation.class); + }); + } + @Nested @CompileWithForkedClassLoader class ConfigurationClassCglibProxy {