From 6c42f374c8df89273ae380d9a99409dab6c5fae1 Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Tue, 20 Jun 2023 16:44:01 +0200 Subject: [PATCH] Consistently set target type in generated code Previously, a bean definition that is optimized AOT could have different metadata based on whether its resolved type had a generic or not. This is due to RootBeanDefinition taking either a Class or a ResolvableType doing fundamentally different things. While the former sets the bean class which is to little use with an instance supplier, the latter specifies the target type of the bean. This commit sets the target type of the bean, using the existing setter methods that take either a class or a ResolvableType and set the same attribute consistently. Closes gh-30689 --- .../DefaultBeanRegistrationCodeFragments.java | 47 +++++++--- .../BeanDefinitionMethodGeneratorTests.java | 90 +++++++++++++++++-- .../beans/factory/aot/TestHierarchy.java | 39 ++++++++ .../ApplicationContextAotGeneratorTests.java | 22 +++++ 4 files changed, 180 insertions(+), 18 deletions(-) create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/TestHierarchy.java diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index 38c0a23072a..26df73d45ee 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.beans.factory.aot; import java.lang.reflect.Constructor; import java.lang.reflect.Executable; +import java.lang.reflect.Modifier; import java.util.List; import java.util.function.Predicate; @@ -47,12 +48,6 @@ import org.springframework.util.ClassUtils; */ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragments { - /** - * The variable name used to hold the bean type. - */ - private static final String BEAN_TYPE_VARIABLE = "beanType"; - - private final BeanRegistrationsCode beanRegistrationsCode; private final RegisteredBean registeredBean; @@ -118,19 +113,45 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) { CodeBlock.Builder code = CodeBlock.builder(); - code.addStatement(generateBeanTypeCode(beanType)); + RootBeanDefinition mergedBeanDefinition = this.registeredBean.getMergedBeanDefinition(); + Class beanClass = (mergedBeanDefinition.hasBeanClass() + ? ClassUtils.getUserClass(mergedBeanDefinition.getBeanClass()) : null); + CodeBlock beanClassCode = generateBeanClassCode( + beanRegistrationCode.getClassName().packageName(), beanClass); code.addStatement("$T $L = new $T($L)", RootBeanDefinition.class, - BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, BEAN_TYPE_VARIABLE); + BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, beanClassCode); + if (targetTypeNecessary(beanType, beanClass)) { + code.addStatement("$L.setTargetType($L)", BEAN_DEFINITION_VARIABLE, + generateBeanTypeCode(beanType)); + } return code.build(); } + private CodeBlock generateBeanClassCode(String targetPackage, @Nullable Class beanClass) { + if (beanClass != null) { + if (Modifier.isPublic(beanClass.getModifiers()) || targetPackage.equals(beanClass.getPackageName())) { + return CodeBlock.of("$T.class", beanClass); + } + else { + return CodeBlock.of("$S", beanClass.getName()); + } + } + return CodeBlock.of(""); + } + private CodeBlock generateBeanTypeCode(ResolvableType beanType) { if (!beanType.hasGenerics()) { - return CodeBlock.of("$T $L = $T.class", Class.class, BEAN_TYPE_VARIABLE, - ClassUtils.getUserClass(beanType.toClass())); + return CodeBlock.of("$T.class", ClassUtils.getUserClass(beanType.toClass())); } - return CodeBlock.of("$T $L = $L", ResolvableType.class, BEAN_TYPE_VARIABLE, - ResolvableTypeCodeGenerator.generateCode(beanType)); + return ResolvableTypeCodeGenerator.generateCode(beanType); + } + + private boolean targetTypeNecessary(ResolvableType beanType, @Nullable Class beanClass) { + if (beanType.hasGenerics() || beanClass == null) { + return true; + } + return (!beanType.toClass().equals(beanClass) + || this.registeredBean.getMergedBeanDefinition().getFactoryMethodName() != null); } @Override diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index f108a3b7b3b..3d490ec1a23 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -47,6 +47,10 @@ import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.factory.aot.InnerBeanConfiguration; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Implementation; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.One; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Two; import org.springframework.core.ResolvableType; import org.springframework.core.test.io.support.MockSpringFactoriesLoader; import org.springframework.core.test.tools.CompileWithForkedClassLoader; @@ -56,6 +60,7 @@ import org.springframework.core.test.tools.TestCompiler; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -89,8 +94,10 @@ class BeanDefinitionMethodGeneratorTests { @Test - void generateBeanDefinitionMethodGeneratesMethod() { - RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class)); + void generateBeanDefinitionMethodWithOnlyTargetTypeDoesNotSetBeanClass() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(); + beanDefinition.setTargetType(TestBean.class); + RegisteredBean registeredBean = registerBean(beanDefinition); BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); @@ -99,7 +106,67 @@ class BeanDefinitionMethodGeneratorTests { compile(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); - assertThat(sourceFile).contains("beanType = TestBean.class"); + assertThat(sourceFile).contains("new RootBeanDefinition()"); + assertThat(sourceFile).contains("setTargetType(TestBean.class)"); + assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateBeanDefinitionMethodSpecifiesBeanClassIfSet() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(TestBean.class)"); + assertThat(sourceFile).doesNotContain("setTargetType("); + assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateBeanDefinitionMethodSpecifiesBeanClassAndTargetTypIfDifferent() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class); + beanDefinition.setTargetType(Implementation.class); + beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean")); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(TestHierarchy.One.class)"); + assertThat(sourceFile).contains("setTargetType(TestHierarchy.Implementation.class)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateBeanDefinitionMethodUSeBeanClassNameIfNotReachable() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(PackagePrivateTestBean.class); + beanDefinition.setTargetType(TestBean.class); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(\"org.springframework.beans.factory.aot.PackagePrivateTestBean\""); + assertThat(sourceFile).contains("setTargetType(TestBean.class)"); assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); assertThat(actual).isInstanceOf(RootBeanDefinition.class); }); @@ -116,7 +183,6 @@ class BeanDefinitionMethodGeneratorTests { compile(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); - assertThat(sourceFile).contains("beanType = TestBean.class"); assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); assertThat(actual).isInstanceOf(RootBeanDefinition.class); }); @@ -183,12 +249,26 @@ class BeanDefinitionMethodGeneratorTests { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); assertThat(sourceFile).contains( - "beanType = ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class)"); + "setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class))"); assertThat(sourceFile).contains("setInstanceSupplier(GenericBean::new)"); assertThat(actual).isInstanceOf(RootBeanDefinition.class); }); } + @Test + void generateBeanDefinitionMethodWhenHasExplicitResolvableType() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class); + beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean")); + beanDefinition.setTargetType(Two.class); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> assertThat(actual.getResolvableType().resolve()).isEqualTo(Two.class)); + } + @Test void generateBeanDefinitionMethodWhenHasInstancePostProcessorGeneratesMethod() { RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class)); diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/TestHierarchy.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/TestHierarchy.java new file mode 100644 index 00000000000..df47ee1637e --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/TestHierarchy.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans.factory.aot; + +/** + * A hierarchy where the exposed type of a bean is a partial signature. + * + * @author Stephane Nicoll + */ +public class TestHierarchy { + + public interface One { + } + + public interface Two { + } + + public static class Implementation implements One, Two { + } + + public static One oneBean() { + return new Implementation(); + } + +} diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java index 01a210cd0d5..1ad30149c49 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java @@ -45,8 +45,13 @@ import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Implementation; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.One; +import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Two; import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.AnnotationConfigUtils; @@ -75,6 +80,7 @@ import org.springframework.core.test.tools.CompileWithForkedClassLoader; import org.springframework.core.test.tools.Compiled; import org.springframework.core.test.tools.TestCompiler; import org.springframework.mock.env.MockEnvironment; +import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -327,6 +333,22 @@ class ApplicationContextAotGeneratorTests { }); } + @Test // gh-30689 + void processAheadOfTimeWithExplicitResolvableType() { + GenericApplicationContext applicationContext = new GenericApplicationContext(); + DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class); + beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean")); + // Override target type + beanDefinition.setTargetType(Two.class); + beanFactory.registerBeanDefinition("hierarchyBean", beanDefinition); + testCompiledResult(applicationContext, (initializer, compiled) -> { + GenericApplicationContext freshApplicationContext = toFreshApplicationContext(initializer); + assertThat(freshApplicationContext.getBean(Two.class)) + .isInstanceOf(Implementation.class); + }); + } + @Nested @CompileWithForkedClassLoader class ConfigurationClassCglibProxy {