Allow target to be a ClassName rather than an existing Class

See gh-29207
This commit is contained in:
Stephane Nicoll 2022-09-22 16:02:25 +02:00
parent 8ef850ff91
commit e6aef11b09
6 changed files with 43 additions and 35 deletions

View File

@ -37,6 +37,7 @@ import org.springframework.beans.factory.support.InstanceSupplier;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.ResolvableType;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.lang.Nullable;
@ -108,8 +109,8 @@ class ScopedProxyBeanRegistrationAotProcessor implements BeanRegistrationAotProc
}
@Override
public Class<?> getTarget(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) {
return this.targetBeanDefinition.getResolvableType().toClass();
public ClassName getTarget(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) {
return ClassName.get(this.targetBeanDefinition.getResolvableType().toClass());
}
@Override

View File

@ -94,9 +94,9 @@ class BeanDefinitionMethodGenerator {
registerRuntimeHintsIfNecessary(generationContext.getRuntimeHints());
BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext,
beanRegistrationsCode);
Class<?> target = codeFragments.getTarget(this.registeredBean,
ClassName target = codeFragments.getTarget(this.registeredBean,
this.constructorOrFactoryMethod);
if (!target.getName().startsWith("java.")) {
if (!target.canonicalName().startsWith("java.")) {
GeneratedClass generatedClass = generationContext.getGeneratedClasses()
.getOrAddForFeatureComponent("BeanDefinitions", target, type -> {
type.addJavadoc("Bean definitions for {@link $T}", target);

View File

@ -26,6 +26,7 @@ import org.springframework.beans.factory.support.InstanceSupplier;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.ResolvableType;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
/**
@ -52,9 +53,9 @@ public interface BeanRegistrationCodeFragments {
* the code.
* @param registeredBean the registered bean
* @param constructorOrFactoryMethod the constructor or factory method
* @return the target class
* @return the target {@link ClassName}
*/
Class<?> getTarget(RegisteredBean registeredBean,
ClassName getTarget(RegisteredBean registeredBean,
Executable constructorOrFactoryMethod);
/**

View File

@ -26,6 +26,7 @@ import org.springframework.aot.generate.MethodReference;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.ResolvableType;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.util.Assert;
@ -51,7 +52,7 @@ public class BeanRegistrationCodeFragmentsDecorator implements BeanRegistrationC
}
@Override
public Class<?> getTarget(RegisteredBean registeredBean,
public ClassName getTarget(RegisteredBean registeredBean,
Executable constructorOrFactoryMethod) {
return this.delegate.getTarget(registeredBean, constructorOrFactoryMethod);

View File

@ -32,6 +32,7 @@ import org.springframework.beans.factory.support.InstanceSupplier;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.ResolvableType;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.ParameterizedTypeName;
import org.springframework.lang.Nullable;
@ -70,7 +71,7 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme
@Override
public Class<?> getTarget(RegisteredBean registeredBean,
public ClassName getTarget(RegisteredBean registeredBean,
Executable constructorOrFactoryMethod) {
Class<?> target = extractDeclaringClass(registeredBean.getBeanType(), constructorOrFactoryMethod);
@ -79,7 +80,7 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme
Assert.state(parent != null, "No parent available for inner bean");
target = parent.getBeanClass();
}
return target;
return ClassName.get(target);
}
private Class<?> extractDeclaringClass(ResolvableType beanType, Executable executable) {

View File

@ -34,6 +34,7 @@ import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean;
import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanConfiguration;
import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanFactoryBean;
import org.springframework.core.ResolvableType;
import org.springframework.javapoet.ClassName;
import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
@ -52,45 +53,45 @@ class DefaultBeanRegistrationCodeFragmentsTests {
@Test
void getTargetOnConstructor() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
assertThat(createInstance(registeredBean).getTarget(registeredBean,
SimpleBean.class.getDeclaredConstructors()[0])).isEqualTo(SimpleBean.class);
assertTarget(createInstance(registeredBean).getTarget(registeredBean,
SimpleBean.class.getDeclaredConstructors()[0]), SimpleBean.class);
}
@Test
void getTargetOnConstructorToPublicFactoryBean() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
assertThat(createInstance(registeredBean).getTarget(registeredBean,
SimpleBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(SimpleBean.class);
assertTarget(createInstance(registeredBean).getTarget(registeredBean,
SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class);
}
@Test
void getTargetOnConstructorToPublicGenericFactoryBeanExtractTargetFromFactoryBeanType() {
RegisteredBean registeredBean = registerTestBean(ResolvableType
.forClassWithGenerics(GenericFactoryBean.class, SimpleBean.class));
assertThat(createInstance(registeredBean).getTarget(registeredBean,
GenericFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(SimpleBean.class);
assertTarget(createInstance(registeredBean).getTarget(registeredBean,
GenericFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class);
}
@Test
void getTargetOnConstructorToPublicGenericFactoryBeanWithBoundExtractTargetFromFactoryBeanType() {
RegisteredBean registeredBean = registerTestBean(ResolvableType
.forClassWithGenerics(NumberFactoryBean.class, Integer.class));
assertThat(createInstance(registeredBean).getTarget(registeredBean,
NumberFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(Integer.class);
assertTarget(createInstance(registeredBean).getTarget(registeredBean,
NumberFactoryBean.class.getDeclaredConstructors()[0]), Integer.class);
}
@Test
void getTargetOnConstructorToPublicGenericFactoryBeanUseBeanTypeAsFallback() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
assertThat(createInstance(registeredBean).getTarget(registeredBean,
GenericFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(SimpleBean.class);
assertTarget(createInstance(registeredBean).getTarget(registeredBean,
GenericFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class);
}
@Test
void getTargetOnConstructorToProtectedFactoryBean() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
assertThat(createInstance(registeredBean).getTarget(registeredBean,
PrivilegedTestBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(
assertTarget(createInstance(registeredBean).getTarget(registeredBean,
PrivilegedTestBeanFactoryBean.class.getDeclaredConstructors()[0]),
PrivilegedTestBeanFactoryBean.class);
}
@ -99,8 +100,8 @@ class DefaultBeanRegistrationCodeFragmentsTests {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
Method method = ReflectionUtils.findMethod(SimpleBeanConfiguration.class, "simpleBean");
assertThat(method).isNotNull();
assertThat(createInstance(registeredBean).getTarget(registeredBean,
method)).isEqualTo(SimpleBeanConfiguration.class);
assertTarget(createInstance(registeredBean).getTarget(registeredBean, method),
SimpleBeanConfiguration.class);
}
@Test
@ -110,16 +111,15 @@ class DefaultBeanRegistrationCodeFragmentsTests {
new RootBeanDefinition(String.class));
Method method = ReflectionUtils.findMethod(getClass(), "createString");
assertThat(method).isNotNull();
assertThat(createInstance(innerBean).getTarget(innerBean,
method)).isEqualTo(getClass());
assertTarget(createInstance(innerBean).getTarget(innerBean, method), getClass());
}
@Test
void getTargetOnConstructorWithInnerBeanInJavaPackage() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
String.class.getDeclaredConstructors()[0])).isEqualTo(SimpleBean.class);
assertTarget(createInstance(innerBean).getTarget(innerBean,
String.class.getDeclaredConstructors()[0]), SimpleBean.class);
}
@Test
@ -127,8 +127,8 @@ class DefaultBeanRegistrationCodeFragmentsTests {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(StringFactoryBean.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
StringFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(SimpleBean.class);
assertTarget(createInstance(innerBean).getTarget(innerBean,
StringFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class);
}
@Test
@ -138,8 +138,8 @@ class DefaultBeanRegistrationCodeFragmentsTests {
new RootBeanDefinition(SimpleBean.class));
Method method = ReflectionUtils.findMethod(SimpleBeanConfiguration.class, "simpleBean");
assertThat(method).isNotNull();
assertThat(createInstance(innerBean).getTarget(innerBean, method))
.isEqualTo(SimpleBeanConfiguration.class);
assertTarget(createInstance(innerBean).getTarget(innerBean, method),
SimpleBeanConfiguration.class);
}
@Test
@ -147,8 +147,8 @@ class DefaultBeanRegistrationCodeFragmentsTests {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(SimpleBean.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
SimpleBean.class.getDeclaredConstructors()[0])).isEqualTo(SimpleBean.class);
assertTarget(createInstance(innerBean).getTarget(innerBean,
SimpleBean.class.getDeclaredConstructors()[0]), SimpleBean.class);
}
@Test
@ -156,8 +156,12 @@ class DefaultBeanRegistrationCodeFragmentsTests {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(SimpleBean.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
SimpleBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(SimpleBean.class);
assertTarget(createInstance(innerBean).getTarget(innerBean,
SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class);
}
private void assertTarget(ClassName target, Class<?> expected) {
assertThat(target).isEqualTo(ClassName.get(expected));
}