Improve location of generated bean definitions of FactoryBeans

This commit improves the location of generated bean definitions for
FactoryBean implementations by checking the type that the factory
bean generates, rather than the factory bean implementation itself.

Closes gh-28812
This commit is contained in:
Stephane Nicoll 2022-07-19 14:59:07 +02:00
parent c0bea373a2
commit 85d4a79cdc
2 changed files with 170 additions and 4 deletions

View File

@ -16,12 +16,14 @@
package org.springframework.beans.factory.aot;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.util.List;
import java.util.function.Predicate;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionHolder;
import org.springframework.beans.factory.support.InstanceSupplier;
@ -69,14 +71,21 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments
public Class<?> getTarget(RegisteredBean registeredBean,
Executable constructorOrFactoryMethod) {
Class<?> target = ClassUtils
.getUserClass(constructorOrFactoryMethod.getDeclaringClass());
Class<?> target = extractDeclaringClass(constructorOrFactoryMethod);
while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) {
target = registeredBean.getParent().getBeanClass();
}
return target;
}
private Class<?> extractDeclaringClass(Executable executable) {
Class<?> declaringClass = ClassUtils.getUserClass(executable.getDeclaringClass());
if (executable instanceof Constructor<?> && FactoryBean.class.isAssignableFrom(declaringClass)) {
return ResolvableType.forType(declaringClass).as(FactoryBean.class).getGeneric(0).toClass();
}
return executable.getDeclaringClass();
}
@Override
public CodeBlock generateNewBeanDefinitionCode(GenerationContext generationContext,
ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) {
@ -107,7 +116,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments
generationContext.getRuntimeHints(), attributeFilter,
beanRegistrationCode.getMethods(),
(name, value) -> generateValueCode(generationContext, name, value))
.generateCode(beanDefinition);
.generateCode(beanDefinition);
}
@Nullable
@ -171,7 +180,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments
return new InstanceSupplierCodeGenerator(generationContext,
beanRegistrationCode.getClassName(),
beanRegistrationCode.getMethods(), allowDirectSupplierShortcut)
.generateCode(this.registeredBean, constructorOrFactoryMethod);
.generateCode(this.registeredBean, constructorOrFactoryMethod);
}
@Override

View File

@ -0,0 +1,157 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.beans.factory.aot;
import java.lang.reflect.Method;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.annotation.InjectAnnotationBeanPostProcessorTests.StringFactoryBean;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.factory.DummyFactory;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link DefaultBeanRegistrationCodeFragments}.
*
* @author Stephane Nicoll
*/
class DefaultBeanRegistrationCodeFragmentsTests {
private final BeanRegistrationsCode beanRegistrationsCode = new MockBeanRegistrationsCode(new TestGenerationContext());
private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
@Test
void getTargetOnConstructor() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
assertThat(createInstance(registeredBean).getTarget(registeredBean,
TestBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
@Test
void getTargetOnConstructorToFactoryBean() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
assertThat(createInstance(registeredBean).getTarget(registeredBean,
TestBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
@Test
void getTargetOnMethod() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
Method method = ReflectionUtils.findMethod(TestBeanFactoryBean.class, "getObject");
assertThat(method).isNotNull();
assertThat(createInstance(registeredBean).getTarget(registeredBean,
method)).isEqualTo(TestBeanFactoryBean.class);
}
@Test
void getTargetOnMethodWithInnerBeanInJavaPackage() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class));
Method method = ReflectionUtils.findMethod(getClass(), "createString");
assertThat(method).isNotNull();
assertThat(createInstance(innerBean).getTarget(innerBean,
method)).isEqualTo(getClass());
}
@Test
void getTargetOnConstructorWithInnerBeanInJavaPackage() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
String.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
@Test
void getTargetOnConstructorWithInnerBeanOnTypeInJavaPackage() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(StringFactoryBean.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
StringFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
@Test
void getTargetOnMethodWithInnerBeanInRegularPackage() {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(TestBean.class));
Method method = ReflectionUtils.findMethod(TestBeanFactoryBean.class, "getObject");
assertThat(method).isNotNull();
assertThat(createInstance(innerBean).getTarget(innerBean, method)).isEqualTo(TestBeanFactoryBean.class);
}
@Test
void getTargetOnConstructorWithInnerBeanInRegularPackage() {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(TestBean.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
TestBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
@Test
void getTargetOnConstructorWithInnerBeanOnFactoryBeanOnTypeInRegularPackage() {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(TestBean.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
TestBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
private RegisteredBean registerTestBean(Class<?> beanType) {
this.beanFactory.registerBeanDefinition("testBean",
new RootBeanDefinition(beanType));
return RegisteredBean.of(this.beanFactory, "testBean");
}
private BeanRegistrationCodeFragments createInstance(RegisteredBean registeredBean) {
return new DefaultBeanRegistrationCodeFragments(this.beanRegistrationsCode, registeredBean, new BeanDefinitionMethodGeneratorFactory(this.beanFactory));
}
@SuppressWarnings("unused")
static String createString() {
return "Test";
}
@SuppressWarnings("unused")
static class TestBean {
}
static class TestBeanFactoryBean implements FactoryBean<TestBean> {
@Override
public TestBean getObject() throws Exception {
return new TestBean();
}
@Override
public Class<?> getObjectType() {
return TestBean.class;
}
}
}