diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java index d176a38e03..a6c518e120 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -96,7 +96,7 @@ class BeanDefinitionMethodGenerator { BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext, beanRegistrationsCode); ClassName target = codeFragments.getTarget(this.registeredBean, this.constructorOrFactoryMethod); - if (!target.canonicalName().startsWith("java.")) { + if (isWritablePackageName(target)) { GeneratedClass generatedClass = lookupGeneratedClass(generationContext, target); GeneratedMethods generatedMethods = generatedClass.getMethods().withPrefix(getName()); GeneratedMethod generatedMethod = generateBeanDefinitionMethod(generationContext, @@ -109,6 +109,16 @@ class BeanDefinitionMethodGenerator { return generatedMethod.toMethodReference(); } + /** + * Specify if the {@link ClassName} belongs to a writable package. + * @param target the target to check + * @return {@code true} if generated code in that package is allowed + */ + private boolean isWritablePackageName(ClassName target) { + String packageName = target.packageName(); + return (!packageName.startsWith("java.") && !packageName.startsWith("javax.")); + } + /** * Return the {@link GeneratedClass} to use for the specified {@code target}. *
If the target class is an inner class, a corresponding inner class in diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index a5a03f2ab7..1715fc740a 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -24,6 +24,7 @@ import java.util.function.Predicate; import java.util.function.Supplier; import javax.lang.model.element.Modifier; +import javax.xml.parsers.DocumentBuilderFactory; import org.junit.jupiter.api.Test; @@ -455,6 +456,37 @@ class BeanDefinitionMethodGeneratorTests { }); } + @Test + void generateBeanDefinitionMethodWhenBeanIsInJavaPackage() { + RootBeanDefinition beanDefinition = (RootBeanDefinition) BeanDefinitionBuilder + .rootBeanDefinition(String.class).addConstructorArgValue("test").getBeanDefinition(); + testBeanDefinitionMethodInCurrentFile(String.class, beanDefinition); + } + + @Test + void generateBeanDefinitionMethodWhenBeanIsInJavaxPackage() { + RootBeanDefinition beanDefinition = (RootBeanDefinition) BeanDefinitionBuilder + .rootBeanDefinition(DocumentBuilderFactory.class).setFactoryMethod("newDefaultInstance").getBeanDefinition(); + testBeanDefinitionMethodInCurrentFile(DocumentBuilderFactory.class, beanDefinition); + } + + private void testBeanDefinitionMethodInCurrentFile(Class> targetType, RootBeanDefinition beanDefinition) { + RegisteredBean registeredBean = registerBean(new RootBeanDefinition(beanDefinition)); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); + freshBeanFactory.registerBeanDefinition("test", actual); + Object bean = freshBeanFactory.getBean("test"); + assertThat(bean).isInstanceOf(targetType); + assertThat(compiled.getSourceFiles().stream().filter(sourceFile -> + sourceFile.getClassName().startsWith(targetType.getPackageName()))).isEmpty(); + }); + } + private RegisteredBean registerBean(RootBeanDefinition beanDefinition) { String beanName = "testBean"; this.beanFactory.registerBeanDefinition(beanName, beanDefinition);