From 85d4a79cdcab3864d942e69f1f85b3b592b190a8 Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Tue, 19 Jul 2022 14:59:07 +0200 Subject: [PATCH] Improve location of generated bean definitions of FactoryBeans This commit improves the location of generated bean definitions for FactoryBean implementations by checking the type that the factory bean generates, rather than the factory bean implementation itself. Closes gh-28812 --- .../DefaultBeanRegistrationCodeFragments.java | 17 +- ...ultBeanRegistrationCodeFragmentsTests.java | 157 ++++++++++++++++++ 2 files changed, 170 insertions(+), 4 deletions(-) create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index cf4e4f045c..97a3df3fd6 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -16,12 +16,14 @@ package org.springframework.beans.factory.aot; +import java.lang.reflect.Constructor; import java.lang.reflect.Executable; import java.util.List; import java.util.function.Predicate; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; import org.springframework.beans.factory.support.InstanceSupplier; @@ -69,14 +71,21 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments public Class getTarget(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { - Class target = ClassUtils - .getUserClass(constructorOrFactoryMethod.getDeclaringClass()); + Class target = extractDeclaringClass(constructorOrFactoryMethod); while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) { target = registeredBean.getParent().getBeanClass(); } return target; } + private Class extractDeclaringClass(Executable executable) { + Class declaringClass = ClassUtils.getUserClass(executable.getDeclaringClass()); + if (executable instanceof Constructor && FactoryBean.class.isAssignableFrom(declaringClass)) { + return ResolvableType.forType(declaringClass).as(FactoryBean.class).getGeneric(0).toClass(); + } + return executable.getDeclaringClass(); + } + @Override public CodeBlock generateNewBeanDefinitionCode(GenerationContext generationContext, ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) { @@ -107,7 +116,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments generationContext.getRuntimeHints(), attributeFilter, beanRegistrationCode.getMethods(), (name, value) -> generateValueCode(generationContext, name, value)) - .generateCode(beanDefinition); + .generateCode(beanDefinition); } @Nullable @@ -171,7 +180,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments return new InstanceSupplierCodeGenerator(generationContext, beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) - .generateCode(this.registeredBean, constructorOrFactoryMethod); + .generateCode(this.registeredBean, constructorOrFactoryMethod); } @Override diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java new file mode 100644 index 0000000000..6ac4dfef37 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java @@ -0,0 +1,157 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Method; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.annotation.InjectAnnotationBeanPostProcessorTests.StringFactoryBean; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.factory.DummyFactory; +import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link DefaultBeanRegistrationCodeFragments}. + * + * @author Stephane Nicoll + */ +class DefaultBeanRegistrationCodeFragmentsTests { + + private final BeanRegistrationsCode beanRegistrationsCode = new MockBeanRegistrationsCode(new TestGenerationContext()); + + private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + + @Test + void getTargetOnConstructor() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + assertThat(createInstance(registeredBean).getTarget(registeredBean, + TestBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + @Test + void getTargetOnConstructorToFactoryBean() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + assertThat(createInstance(registeredBean).getTarget(registeredBean, + TestBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + @Test + void getTargetOnMethod() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + Method method = ReflectionUtils.findMethod(TestBeanFactoryBean.class, "getObject"); + assertThat(method).isNotNull(); + assertThat(createInstance(registeredBean).getTarget(registeredBean, + method)).isEqualTo(TestBeanFactoryBean.class); + } + + @Test + void getTargetOnMethodWithInnerBeanInJavaPackage() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class)); + Method method = ReflectionUtils.findMethod(getClass(), "createString"); + assertThat(method).isNotNull(); + assertThat(createInstance(innerBean).getTarget(innerBean, + method)).isEqualTo(getClass()); + } + + @Test + void getTargetOnConstructorWithInnerBeanInJavaPackage() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class)); + assertThat(createInstance(innerBean).getTarget(innerBean, + String.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + @Test + void getTargetOnConstructorWithInnerBeanOnTypeInJavaPackage() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", + new RootBeanDefinition(StringFactoryBean.class)); + assertThat(createInstance(innerBean).getTarget(innerBean, + StringFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + @Test + void getTargetOnMethodWithInnerBeanInRegularPackage() { + RegisteredBean registeredBean = registerTestBean(DummyFactory.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(TestBean.class)); + Method method = ReflectionUtils.findMethod(TestBeanFactoryBean.class, "getObject"); + assertThat(method).isNotNull(); + assertThat(createInstance(innerBean).getTarget(innerBean, method)).isEqualTo(TestBeanFactoryBean.class); + } + + @Test + void getTargetOnConstructorWithInnerBeanInRegularPackage() { + RegisteredBean registeredBean = registerTestBean(DummyFactory.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(TestBean.class)); + assertThat(createInstance(innerBean).getTarget(innerBean, + TestBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + @Test + void getTargetOnConstructorWithInnerBeanOnFactoryBeanOnTypeInRegularPackage() { + RegisteredBean registeredBean = registerTestBean(DummyFactory.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", + new RootBeanDefinition(TestBean.class)); + assertThat(createInstance(innerBean).getTarget(innerBean, + TestBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + + private RegisteredBean registerTestBean(Class beanType) { + this.beanFactory.registerBeanDefinition("testBean", + new RootBeanDefinition(beanType)); + return RegisteredBean.of(this.beanFactory, "testBean"); + } + + private BeanRegistrationCodeFragments createInstance(RegisteredBean registeredBean) { + return new DefaultBeanRegistrationCodeFragments(this.beanRegistrationsCode, registeredBean, new BeanDefinitionMethodGeneratorFactory(this.beanFactory)); + } + + @SuppressWarnings("unused") + static String createString() { + return "Test"; + } + + @SuppressWarnings("unused") + static class TestBean { + + } + + + static class TestBeanFactoryBean implements FactoryBean { + + @Override + public TestBean getObject() throws Exception { + return new TestBean(); + } + + @Override + public Class getObjectType() { + return TestBean.class; + } + } + +}