Consistently set target type in generated code

Previously, a bean definition that is optimized AOT could have
different metadata based on whether its resolved type had a generic or
not. This is due to RootBeanDefinition taking either a Class or a
ResolvableType doing fundamentally different things. While the former
sets the bean class which is to little use with an instance supplier,
the latter specifies the target type of the bean.

This commit sets the target type of the bean, using the existing
setter methods that take either a class or a ResolvableType and set the
same attribute consistently.

Closes gh-30689
This commit is contained in:
Stephane Nicoll 2023-06-20 16:44:01 +02:00
parent 4879b56bb9
commit 6c42f374c8
4 changed files with 180 additions and 18 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ package org.springframework.beans.factory.aot;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.function.Predicate;
@ -47,12 +48,6 @@ import org.springframework.util.ClassUtils;
*/
class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragments {
/**
* The variable name used to hold the bean type.
*/
private static final String BEAN_TYPE_VARIABLE = "beanType";
private final BeanRegistrationsCode beanRegistrationsCode;
private final RegisteredBean registeredBean;
@ -118,19 +113,45 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme
ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) {
CodeBlock.Builder code = CodeBlock.builder();
code.addStatement(generateBeanTypeCode(beanType));
RootBeanDefinition mergedBeanDefinition = this.registeredBean.getMergedBeanDefinition();
Class<?> beanClass = (mergedBeanDefinition.hasBeanClass()
? ClassUtils.getUserClass(mergedBeanDefinition.getBeanClass()) : null);
CodeBlock beanClassCode = generateBeanClassCode(
beanRegistrationCode.getClassName().packageName(), beanClass);
code.addStatement("$T $L = new $T($L)", RootBeanDefinition.class,
BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, BEAN_TYPE_VARIABLE);
BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, beanClassCode);
if (targetTypeNecessary(beanType, beanClass)) {
code.addStatement("$L.setTargetType($L)", BEAN_DEFINITION_VARIABLE,
generateBeanTypeCode(beanType));
}
return code.build();
}
private CodeBlock generateBeanClassCode(String targetPackage, @Nullable Class<?> beanClass) {
if (beanClass != null) {
if (Modifier.isPublic(beanClass.getModifiers()) || targetPackage.equals(beanClass.getPackageName())) {
return CodeBlock.of("$T.class", beanClass);
}
else {
return CodeBlock.of("$S", beanClass.getName());
}
}
return CodeBlock.of("");
}
private CodeBlock generateBeanTypeCode(ResolvableType beanType) {
if (!beanType.hasGenerics()) {
return CodeBlock.of("$T<?> $L = $T.class", Class.class, BEAN_TYPE_VARIABLE,
ClassUtils.getUserClass(beanType.toClass()));
return CodeBlock.of("$T.class", ClassUtils.getUserClass(beanType.toClass()));
}
return CodeBlock.of("$T $L = $L", ResolvableType.class, BEAN_TYPE_VARIABLE,
ResolvableTypeCodeGenerator.generateCode(beanType));
return ResolvableTypeCodeGenerator.generateCode(beanType);
}
private boolean targetTypeNecessary(ResolvableType beanType, @Nullable Class<?> beanClass) {
if (beanType.hasGenerics() || beanClass == null) {
return true;
}
return (!beanType.toClass().equals(beanClass)
|| this.registeredBean.getMergedBeanDefinition().getFactoryMethodName() != null);
}
@Override

View File

@ -47,6 +47,10 @@ import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.beans.testfixture.beans.factory.aot.InnerBeanConfiguration;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode;
import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Implementation;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.One;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Two;
import org.springframework.core.ResolvableType;
import org.springframework.core.test.io.support.MockSpringFactoriesLoader;
import org.springframework.core.test.tools.CompileWithForkedClassLoader;
@ -56,6 +60,7 @@ import org.springframework.core.test.tools.TestCompiler;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.ParameterizedTypeName;
import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@ -89,8 +94,10 @@ class BeanDefinitionMethodGeneratorTests {
@Test
void generateBeanDefinitionMethodGeneratesMethod() {
RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class));
void generateBeanDefinitionMethodWithOnlyTargetTypeDoesNotSetBeanClass() {
RootBeanDefinition beanDefinition = new RootBeanDefinition();
beanDefinition.setTargetType(TestBean.class);
RegisteredBean registeredBean = registerBean(beanDefinition);
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList());
@ -99,7 +106,67 @@ class BeanDefinitionMethodGeneratorTests {
compile(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains("beanType = TestBean.class");
assertThat(sourceFile).contains("new RootBeanDefinition()");
assertThat(sourceFile).contains("setTargetType(TestBean.class)");
assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
}
@Test
void generateBeanDefinitionMethodSpecifiesBeanClassIfSet() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class);
RegisteredBean registeredBean = registerBean(beanDefinition);
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, this.beanRegistrationsCode);
compile(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains("new RootBeanDefinition(TestBean.class)");
assertThat(sourceFile).doesNotContain("setTargetType(");
assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
}
@Test
void generateBeanDefinitionMethodSpecifiesBeanClassAndTargetTypIfDifferent() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class);
beanDefinition.setTargetType(Implementation.class);
beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean"));
RegisteredBean registeredBean = registerBean(beanDefinition);
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, this.beanRegistrationsCode);
compile(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains("new RootBeanDefinition(TestHierarchy.One.class)");
assertThat(sourceFile).contains("setTargetType(TestHierarchy.Implementation.class)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
}
@Test
void generateBeanDefinitionMethodUSeBeanClassNameIfNotReachable() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(PackagePrivateTestBean.class);
beanDefinition.setTargetType(TestBean.class);
RegisteredBean registeredBean = registerBean(beanDefinition);
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, this.beanRegistrationsCode);
compile(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains("new RootBeanDefinition(\"org.springframework.beans.factory.aot.PackagePrivateTestBean\"");
assertThat(sourceFile).contains("setTargetType(TestBean.class)");
assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
@ -116,7 +183,6 @@ class BeanDefinitionMethodGeneratorTests {
compile(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains("beanType = TestBean.class");
assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
@ -183,12 +249,26 @@ class BeanDefinitionMethodGeneratorTests {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains(
"beanType = ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class)");
"setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class))");
assertThat(sourceFile).contains("setInstanceSupplier(GenericBean::new)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
}
@Test
void generateBeanDefinitionMethodWhenHasExplicitResolvableType() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class);
beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean"));
beanDefinition.setTargetType(Two.class);
RegisteredBean registeredBean = registerBean(beanDefinition);
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, this.beanRegistrationsCode);
compile(method, (actual, compiled) -> assertThat(actual.getResolvableType().resolve()).isEqualTo(Two.class));
}
@Test
void generateBeanDefinitionMethodWhenHasInstancePostProcessorGeneratesMethod() {
RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class));

View File

@ -0,0 +1,39 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.beans.testfixture.beans.factory.aot;
/**
* A hierarchy where the exposed type of a bean is a partial signature.
*
* @author Stephane Nicoll
*/
public class TestHierarchy {
public interface One {
}
public interface Two {
}
public static class Implementation implements One, Two {
}
public static One oneBean() {
return new Implementation();
}
}

View File

@ -45,8 +45,13 @@ import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Implementation;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.One;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Two;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.AnnotationConfigUtils;
@ -75,6 +80,7 @@ import org.springframework.core.test.tools.CompileWithForkedClassLoader;
import org.springframework.core.test.tools.Compiled;
import org.springframework.core.test.tools.TestCompiler;
import org.springframework.mock.env.MockEnvironment;
import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
@ -327,6 +333,22 @@ class ApplicationContextAotGeneratorTests {
});
}
@Test // gh-30689
void processAheadOfTimeWithExplicitResolvableType() {
GenericApplicationContext applicationContext = new GenericApplicationContext();
DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory();
RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class);
beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean"));
// Override target type
beanDefinition.setTargetType(Two.class);
beanFactory.registerBeanDefinition("hierarchyBean", beanDefinition);
testCompiledResult(applicationContext, (initializer, compiled) -> {
GenericApplicationContext freshApplicationContext = toFreshApplicationContext(initializer);
assertThat(freshApplicationContext.getBean(Two.class))
.isInstanceOf(Implementation.class);
});
}
@Nested
@CompileWithForkedClassLoader
class ConfigurationClassCglibProxy {