From 4d87071f3a6d94977d471cd9a31e168f63f95442 Mon Sep 17 00:00:00 2001 From: Phillip Webb Date: Wed, 13 Apr 2022 20:20:59 -0700 Subject: [PATCH] Add AOT interfaces and classes to support bean factories Add AOT processor and contribution interfaces and classes to support the generation of code that can re-hydrate a bean factory. See gh-28414 --- .../AutowiredArgumentsCodeGenerator.java | 105 ++++ .../beans/factory/aot/AotFactoriesLoader.java | 103 ++++ .../aot/BeanDefinitionMethodGenerator.java | 217 +++++++ .../BeanDefinitionMethodGeneratorFactory.java | 140 +++++ ...BeanDefinitionPropertiesCodeGenerator.java | 264 +++++++++ ...nDefinitionPropertyValueCodeGenerator.java | 529 ++++++++++++++++++ ...nFactoryInitializationAotContribution.java | 41 ++ ...BeanFactoryInitializationAotProcessor.java | 52 ++ .../aot/BeanFactoryInitializationCode.java | 52 ++ .../aot/BeanRegistrationAotContribution.java | 40 ++ .../aot/BeanRegistrationAotProcessor.java | 50 ++ .../factory/aot/BeanRegistrationCode.java | 57 ++ .../aot/BeanRegistrationCodeFragments.java | 166 ++++++ ...anRegistrationCodeFragmentsCustomizer.java | 45 ++ .../aot/BeanRegistrationCodeGenerator.java | 99 ++++ .../aot/BeanRegistrationExcludeFilter.java | 40 ++ .../aot/BeanRegistrationsAotContribution.java | 135 +++++ .../aot/BeanRegistrationsAotProcessor.java | 55 ++ .../factory/aot/BeanRegistrationsCode.java | 44 ++ .../ConstructorOrFactoryMethodResolver.java | 450 +++++++++++++++ .../DefaultBeanRegistrationCodeFragments.java | 185 ++++++ .../aot/InstanceSupplierCodeGenerator.java | 364 ++++++++++++ .../aot/ResolvableTypeCodeGenerator.java | 69 +++ .../beans/factory/aot/package-info.java | 9 + .../resources/META-INF/spring/aot.factories | 2 + .../AutowiredArgumentsCodeGeneratorTests.java | 200 +++++++ .../factory/aot/AotFactoriesLoaderTests.java | 100 ++++ ...DefinitionMethodGeneratorFactoryTests.java | 164 ++++++ .../BeanDefinitionMethodGeneratorTests.java | 402 +++++++++++++ ...efinitionPropertiesCodeGeneratorTests.java | 441 +++++++++++++++ ...nitionPropertyValueCodeGeneratorTests.java | 483 ++++++++++++++++ ...BeanRegistrationsAotContributionTests.java | 178 ++++++ .../BeanRegistrationsAotProcessorTests.java | 49 ++ ...nstructorOrFactoryMethodResolverTests.java | 499 +++++++++++++++++ .../beans/factory/aot/EnumWithClassBody.java | 43 ++ .../aot/ExampleClass$$GeneratedBy.java | 26 + .../beans/factory/aot/ExampleClass.java | 26 + .../InstanceSupplierCodeGeneratorTests.java | 347 ++++++++++++ .../aot/MockBeanRegistrationsCode.java | 54 ++ .../factory/aot/PackagePrivateTestBean.java | 26 + .../TestBeanRegistrationsAotProcessor.java | 26 + ...TestBeanWithPackagePrivateConstructor.java | 24 + .../beans/TestBeanWithPrivateConstructor.java | 24 + .../beans/TestBeanWithPrivateMethod.java | 28 + .../beans/TestBeanWithPublicField.java | 23 + .../generator/SimpleConfiguration.java | 17 +- 46 files changed, 6492 insertions(+), 1 deletion(-) create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredArgumentsCodeGenerator.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/AotFactoriesLoader.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorFactory.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationAotContribution.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationAotProcessor.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationAotContribution.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationAotProcessor.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCode.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsCustomizer.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationExcludeFilter.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotProcessor.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsCode.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/ConstructorOrFactoryMethodResolver.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/ResolvableTypeCodeGenerator.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/aot/package-info.java create mode 100644 spring-beans/src/main/resources/META-INF/spring/aot.factories create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredArgumentsCodeGeneratorTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/AotFactoriesLoaderTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorFactoryTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotProcessorTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/ConstructorOrFactoryMethodResolverTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/EnumWithClassBody.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/ExampleClass$$GeneratedBy.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/ExampleClass.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/MockBeanRegistrationsCode.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/aot/PackagePrivateTestBean.java create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/factory/aot/TestBeanRegistrationsAotProcessor.java create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateConstructor.java create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPrivateConstructor.java create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPrivateMethod.java create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPublicField.java diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredArgumentsCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredArgumentsCodeGenerator.java new file mode 100644 index 00000000000..db68b7cba80 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredArgumentsCodeGenerator.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.annotation; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.function.Predicate; + +import org.springframework.javapoet.CodeBlock; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + +/** + * Code generator to apply {@link AutowiredArguments}. + *

+ * Generates code in the form:

{@code
+ * args.get(0), args.get(1)
+ * }
or
{@code
+ * args.get(0, String.class), args.get(1, Integer.class)
+ * }
+ *

+ * The simpler form is only used if the target method or constructor is + * unambiguous. + * + * @author Phillip Webb + * @author Stephane Nicoll + * @since 6.0 + */ +public class AutowiredArgumentsCodeGenerator { + + private final Class target; + + private final Executable executable; + + + public AutowiredArgumentsCodeGenerator(Class target, Executable executable) { + this.target = target; + this.executable = executable; + } + + + public CodeBlock generateCode(Class[] parameterTypes) { + return generateCode(parameterTypes, 0, "args"); + } + + public CodeBlock generateCode(Class[] parameterTypes, int startIndex) { + return generateCode(parameterTypes, startIndex, "args"); + } + + public CodeBlock generateCode(Class[] parameterTypes, int startIndex, + String variableName) { + + Assert.notNull(parameterTypes, "ParameterTypes must not be null"); + Assert.notNull(variableName, "VariableName must not be null"); + boolean ambiguous = isAmbiguous(); + CodeBlock.Builder builder = CodeBlock.builder(); + for (int i = startIndex; i < parameterTypes.length; i++) { + builder.add((i != startIndex) ? ", " : ""); + if (!ambiguous) { + builder.add("$L.get($L)", variableName, i - startIndex); + } + else { + builder.add("$L.get($L, $T.class)", variableName, i - startIndex, + parameterTypes[i]); + } + } + return builder.build(); + } + + private boolean isAmbiguous() { + if (this.executable instanceof Constructor constructor) { + return Arrays.stream(this.target.getDeclaredConstructors()) + .filter(Predicate.not(constructor::equals)) + .anyMatch(this::hasSameParameterCount); + } + if (this.executable instanceof Method method) { + return Arrays.stream(ReflectionUtils.getAllDeclaredMethods(this.target)) + .filter(Predicate.not(method::equals)) + .filter(candidate -> candidate.getName().equals(method.getName())) + .anyMatch(this::hasSameParameterCount); + } + return true; + } + + private boolean hasSameParameterCount(Executable executable) { + return this.executable.getParameterCount() == executable.getParameterCount(); + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/AotFactoriesLoader.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/AotFactoriesLoader.java new file mode 100644 index 00000000000..a57134b3bc1 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/AotFactoriesLoader.java @@ -0,0 +1,103 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.core.io.support.SpringFactoriesLoader; +import org.springframework.util.Assert; + +/** + * AOT specific factory loading mechanism for internal use within the framework. + *

+ * Loads and instantiates factories of a given type from + * {@value #FACTORIES_RESOURCE_LOCATION} and merges them with matching beans + * from a {@link ListableBeanFactory}. + * + * @author Phillip Webb + * @since 6.0 + * @see SpringFactoriesLoader + */ +public class AotFactoriesLoader { + + /** + * The location to look for AOT factories. + */ + public static final String FACTORIES_RESOURCE_LOCATION = "META-INF/spring/aot.factories"; + + + private final ListableBeanFactory beanFactory; + + private final SpringFactoriesLoader factoriesLoader; + + + /** + * Create a new {@link AotFactoriesLoader} instance backed by the given bean + * factory. + * @param beanFactory the bean factory to use + */ + public AotFactoriesLoader(ListableBeanFactory beanFactory) { + Assert.notNull(beanFactory, "BeanFactory must not be null"); + ClassLoader classLoader = (beanFactory instanceof ConfigurableBeanFactory configurableBeanFactory) + ? configurableBeanFactory.getBeanClassLoader() : null; + this.beanFactory = beanFactory; + this.factoriesLoader = SpringFactoriesLoader.forResourceLocation(classLoader, + FACTORIES_RESOURCE_LOCATION); + } + + /** + * Create a new {@link AotFactoriesLoader} instance backed by the given bean + * factory and loading items from the given {@link SpringFactoriesLoader} + * rather than from {@value #FACTORIES_RESOURCE_LOCATION}. + * @param beanFactory the bean factory to use + * @param factoriesLoader the factories loader to use + */ + public AotFactoriesLoader(ListableBeanFactory beanFactory, + SpringFactoriesLoader factoriesLoader) { + + Assert.notNull(beanFactory, "BeanFactory must not be null"); + Assert.notNull(factoriesLoader, "FactoriesLoader must not be null"); + this.beanFactory = beanFactory; + this.factoriesLoader = factoriesLoader; + } + + + /** + * Load items from factories file and merge them with any beans defined in + * the {@link DefaultListableBeanFactory}. + * @param the item type + * @param type the item type to load + * @return a list of loaded instances + */ + public List load(Class type) { + List result = new ArrayList<>(); + result.addAll(BeanFactoryUtils + .beansOfTypeIncludingAncestors(this.beanFactory, type, true, false) + .values()); + result.addAll(this.factoriesLoader.load(type)); + AnnotationAwareOrderComparator.sort(result); + return Collections.unmodifiableList(result); + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java new file mode 100644 index 00000000000..ad26b4a5a05 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -0,0 +1,217 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Executable; +import java.util.List; + +import javax.lang.model.element.Modifier; + +import org.springframework.aot.generate.ClassGenerator.JavaFileGenerator; +import org.springframework.aot.generate.GeneratedClass; +import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.MethodNameGenerator; +import org.springframework.aot.generate.MethodReference; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.TypeSpec; +import org.springframework.lang.Nullable; + +/** + * Generates a method that returns a {@link BeanDefinition} to be registered. + * + * @author Phillip Webb + * @since 6.0 + * @see BeanDefinitionMethodGeneratorFactory + */ +class BeanDefinitionMethodGenerator { + + private final BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; + + private final RegisteredBean registeredBean; + + private final Executable constructorOrFactoryMethod; + + @Nullable + private final String innerBeanPropertyName; + + private final List aotContributions; + + private final List codeFragmentsCustomizers; + + + /** + * Create a new {@link BeanDefinitionMethodGenerator} instance. + * @param methodGeneratorFactory the method generator factory + * @param registeredBean the registered bean + * @param innerBeanPropertyName the inner bean property name + * @param aotContributions the AOT contributions + * @param codeFragmentsCustomizers the code fragments customizers + */ + BeanDefinitionMethodGenerator( + BeanDefinitionMethodGeneratorFactory methodGeneratorFactory, + RegisteredBean registeredBean, @Nullable String innerBeanPropertyName, + List aotContributions, + List codeFragmentsCustomizers) { + + this.methodGeneratorFactory = methodGeneratorFactory; + this.registeredBean = registeredBean; + this.constructorOrFactoryMethod = ConstructorOrFactoryMethodResolver + .resolve(registeredBean); + this.innerBeanPropertyName = innerBeanPropertyName; + this.aotContributions = aotContributions; + this.codeFragmentsCustomizers = codeFragmentsCustomizers; + } + + /** + * Generate the method that returns the {@link BeanDefinition} to be + * registered. + * @param generationContext the generation context + * @param beanRegistrationsCode the bean registrations code + * @return a reference to the generated method. + */ + MethodReference generateBeanDefinitionMethod(GenerationContext generationContext, + BeanRegistrationsCode beanRegistrationsCode) { + + BeanRegistrationCodeFragments codeFragments = getCodeFragments( + beanRegistrationsCode); + Class target = codeFragments.getTarget(this.registeredBean, + this.constructorOrFactoryMethod); + if (!target.getName().startsWith("java.")) { + GeneratedClass generatedClass = generationContext.getClassGenerator() + .getOrGenerateClass(new BeanDefinitionsJavaFileGenerator(target), + target, "BeanDefinitions"); + MethodGenerator methodGenerator = generatedClass.getMethodGenerator() + .withName(getName()); + GeneratedMethod generatedMethod = generateBeanDefinitionMethod( + generationContext, generatedClass.getName(), methodGenerator, + codeFragments, Modifier.PUBLIC); + return MethodReference.ofStatic(generatedClass.getName(), + generatedMethod.getName()); + } + MethodGenerator methodGenerator = beanRegistrationsCode.getMethodGenerator() + .withName(getName()); + GeneratedMethod generatedMethod = generateBeanDefinitionMethod(generationContext, + beanRegistrationsCode.getClassName(), methodGenerator, codeFragments, + Modifier.PRIVATE); + return MethodReference.ofStatic(beanRegistrationsCode.getClassName(), + generatedMethod.getName().toString()); + + } + + private GeneratedMethod generateBeanDefinitionMethod( + GenerationContext generationContext, ClassName className, + MethodGenerator methodGenerator, BeanRegistrationCodeFragments codeFragments, + Modifier modifier) { + + BeanRegistrationCodeGenerator codeGenerator = new BeanRegistrationCodeGenerator( + className, methodGenerator, this.registeredBean, + this.constructorOrFactoryMethod, codeFragments); + GeneratedMethod method = methodGenerator.generateMethod("get", "bean", + "definition"); + this.aotContributions.forEach(aotContribution -> aotContribution + .applyTo(generationContext, codeGenerator)); + return method.using(builder -> { + builder.addJavadoc("Get the $L definition for '$L'", + (!this.registeredBean.isInnerBean()) ? "bean" : "inner-bean", + getName()); + builder.addModifiers(modifier, Modifier.STATIC); + builder.returns(BeanDefinition.class); + builder.addCode(codeGenerator.generateCode(generationContext)); + }); + } + + private BeanRegistrationCodeFragments getCodeFragments( + BeanRegistrationsCode beanRegistrationsCode) { + + BeanRegistrationCodeFragments codeFragments = new DefaultBeanRegistrationCodeFragments( + beanRegistrationsCode, this.registeredBean, this.methodGeneratorFactory); + for (BeanRegistrationCodeFragmentsCustomizer customizer : this.codeFragmentsCustomizers) { + codeFragments = customizer.customizeBeanRegistrationCodeFragments( + this.registeredBean, codeFragments); + } + return codeFragments; + } + + private String getName() { + if (this.innerBeanPropertyName != null) { + return this.innerBeanPropertyName; + } + if (!this.registeredBean.isGeneratedBeanName()) { + return getSimpleBeanName(this.registeredBean.getBeanName()); + } + RegisteredBean nonGeneratedParent = this.registeredBean; + while (nonGeneratedParent != null && nonGeneratedParent.isGeneratedBeanName()) { + nonGeneratedParent = nonGeneratedParent.getParent(); + } + return (nonGeneratedParent != null) + ? MethodNameGenerator.join( + getSimpleBeanName(nonGeneratedParent.getBeanName()), "innerBean") + : "innerBean"; + } + + private String getSimpleBeanName(String beanName) { + int lastDot = beanName.lastIndexOf('.'); + beanName = (lastDot != -1) ? beanName.substring(lastDot + 1) : beanName; + int lastDollar = beanName.lastIndexOf('$'); + beanName = (lastDollar != -1) ? beanName.substring(lastDollar + 1) : beanName; + return beanName; + } + + + /** + * {@link BeanDefinitionsJavaFileGenerator} to create the + * {@code BeanDefinitions} file. + */ + private static class BeanDefinitionsJavaFileGenerator implements JavaFileGenerator { + + private final Class target; + + + BeanDefinitionsJavaFileGenerator(Class target) { + this.target = target; + } + + + @Override + public JavaFile generateJavaFile(ClassName className, GeneratedMethods methods) { + TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); + classBuilder.addJavadoc("Bean definitions for {@link $T}", this.target); + classBuilder.addModifiers(Modifier.PUBLIC); + methods.doWithMethodSpecs(classBuilder::addMethod); + return JavaFile.builder(className.packageName(), classBuilder.build()) + .build(); + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } + + @Override + public boolean equals(Object obj) { + return getClass() == obj.getClass(); + } + + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorFactory.java new file mode 100644 index 00000000000..f4b741394cb --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorFactory.java @@ -0,0 +1,140 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.core.log.LogMessage; +import org.springframework.lang.Nullable; +import org.springframework.util.ObjectUtils; + +/** + * Factory used to create a {@link BeanDefinitionMethodGenerator} instance for a + * {@link RegisteredBean}. + * + * @author Phillip Webb + * @since 6.0 + * @see BeanDefinitionMethodGenerator + * @see #getBeanDefinitionMethodGenerator(RegisteredBean, String) + */ +class BeanDefinitionMethodGeneratorFactory { + + private static final Log logger = LogFactory + .getLog(BeanDefinitionMethodGeneratorFactory.class); + + + private final List aotProcessors; + + private final List excludeFilters; + + private final List codeGenerationCustomizers; + + + /** + * Create a new {@link BeanDefinitionMethodGeneratorFactory} backed by the + * given {@link ConfigurableListableBeanFactory}. + * @param beanFactory the bean factory use + */ + BeanDefinitionMethodGeneratorFactory(ConfigurableListableBeanFactory beanFactory) { + this(new AotFactoriesLoader(beanFactory)); + } + + /** + * Create a new {@link BeanDefinitionMethodGeneratorFactory} backed by the + * given {@link AotFactoriesLoader}. + * @param loader the AOT factory loader to use + */ + BeanDefinitionMethodGeneratorFactory(AotFactoriesLoader loader) { + this.aotProcessors = loader.load(BeanRegistrationAotProcessor.class); + this.excludeFilters = loader.load(BeanRegistrationExcludeFilter.class); + this.codeGenerationCustomizers = loader + .load(BeanRegistrationCodeFragmentsCustomizer.class); + } + + + /** + * Return a {@link BeanDefinitionMethodGenerator} for the given + * {@link RegisteredBean} or {@code null} if the registered bean is excluded + * by a {@link BeanRegistrationExcludeFilter}. The resulting + * {@link BeanDefinitionMethodGenerator} will include all + * {@link BeanRegistrationAotProcessor} provided contributions. + * @param registeredBean the registered bean + * @return a new {@link BeanDefinitionMethodGenerator} instance or + * {@code null} + */ + @Nullable + BeanDefinitionMethodGenerator getBeanDefinitionMethodGenerator( + RegisteredBean registeredBean, @Nullable String innerBeanPropertyName) { + + if (isExcluded(registeredBean)) { + return null; + } + List contributions = getAotContributions( + registeredBean); + return new BeanDefinitionMethodGenerator(this, registeredBean, + innerBeanPropertyName, contributions, this.codeGenerationCustomizers); + } + + private boolean isExcluded(RegisteredBean registeredBean) { + if (isImplicitlyExcluded(registeredBean)) { + return true; + } + for (BeanRegistrationExcludeFilter excludeFilter : this.excludeFilters) { + if (excludeFilter.isExcluded(registeredBean)) { + logger.trace(LogMessage.format( + "Excluding registered bean '%s' from bean factory %s due to %s", + registeredBean.getBeanName(), + ObjectUtils.identityToString(registeredBean.getBeanFactory()), + excludeFilter.getClass().getName())); + return true; + } + } + return false; + } + + private boolean isImplicitlyExcluded(RegisteredBean registeredBean) { + Class beanClass = registeredBean.getBeanClass(); + return BeanFactoryInitializationAotProcessor.class.isAssignableFrom(beanClass) + || BeanRegistrationAotProcessor.class.isAssignableFrom(beanClass); + } + + private List getAotContributions( + RegisteredBean registeredBean) { + + String beanName = registeredBean.getBeanName(); + List contributions = new ArrayList<>(); + for (BeanRegistrationAotProcessor aotProcessor : this.aotProcessors) { + BeanRegistrationAotContribution contribution = aotProcessor + .processAheadOfTime(registeredBean); + if (contribution != null) { + logger.trace(LogMessage.format( + "Adding bean registration AOT contribution %S from %S to '%S'", + contribution.getClass().getName(), + aotProcessor.getClass().getName(), beanName)); + contributions.add(contribution); + } + } + return contributions; + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java new file mode 100644 index 00000000000..02d1d991638 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java @@ -0,0 +1,264 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Method; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.BiPredicate; +import java.util.function.Function; +import java.util.function.Predicate; + +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.beans.MutablePropertyValues; +import org.springframework.beans.PropertyValue; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; +import org.springframework.beans.factory.support.AbstractBeanDefinition; +import org.springframework.beans.factory.support.InstanceSupplier; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.StringUtils; + +/** + * Internal code generator to set {@link RootBeanDefinition} properties. + *

+ * Generates code in the following form:

+ * beanDefinition.setPrimary(true);
+ * beanDefinition.setScope(BeanDefinition.SCOPE_PROTOTYPE);
+ * ...
+ * 
+ *

+ * The generated code expects the following variables to be available: + *

+ *

+ *

+ * Note that this generator does not set the {@link InstanceSupplier}. + * + * @author Phillip Webb + * @author Stephane Nicoll + * @since 6.0 + */ +class BeanDefinitionPropertiesCodeGenerator { + + private static final RootBeanDefinition DEFAULT_BEAN_DEFINITION = new RootBeanDefinition(); + + private static final String BEAN_DEFINITION_VARIABLE = BeanRegistrationCodeFragments.BEAN_DEFINITION_VARIABLE; + + + private final RuntimeHints hints; + + private final Predicate attributeFilter; + + private final BiFunction customValueCodeGenerator; + + private final BeanDefinitionPropertyValueCodeGenerator valueCodeGenerator; + + + BeanDefinitionPropertiesCodeGenerator(RuntimeHints hints, + Predicate attributeFilter, MethodGenerator methodGenerator, + BiFunction customValueCodeGenerator) { + + this.hints = hints; + this.attributeFilter = attributeFilter; + this.customValueCodeGenerator = customValueCodeGenerator; + this.valueCodeGenerator = new BeanDefinitionPropertyValueCodeGenerator( + methodGenerator); + } + + + CodeBlock generateCode(BeanDefinition beanDefinition) { + CodeBlock.Builder builder = CodeBlock.builder(); + addStatementForValue(builder, beanDefinition, BeanDefinition::isPrimary, + "$L.setPrimary($L)"); + addStatementForValue(builder, beanDefinition, BeanDefinition::getScope, + this::hasScope, "$L.setScope($S)"); + addStatementForValue(builder, beanDefinition, BeanDefinition::getDependsOn, + this::hasDependsOn, "$L.setDependsOn($L)", this::toStringVarArgs); + addStatementForValue(builder, beanDefinition, BeanDefinition::isAutowireCandidate, + "$L.setAutowireCandidate($L)"); + addStatementForValue(builder, beanDefinition, BeanDefinition::getRole, + this::hasRole, "$L.setRole($L)", this::toRole); + if (beanDefinition instanceof AbstractBeanDefinition abstractBeanDefinition) { + addStatementForValue(builder, beanDefinition, + AbstractBeanDefinition::getLazyInit, "$L.setLazyInit($L)"); + addStatementForValue(builder, beanDefinition, + AbstractBeanDefinition::isSynthetic, "$L.setSynthetic($L)"); + addInitDestroyMethods(builder, abstractBeanDefinition, + abstractBeanDefinition.getInitMethodNames(), + "$L.setInitMethodNames($L)"); + addInitDestroyMethods(builder, abstractBeanDefinition, + abstractBeanDefinition.getDestroyMethodNames(), + "$L.setDestroyMethodNames($L)"); + } + addConstructorArgumentValues(builder, beanDefinition); + addPropertyValues(builder, beanDefinition); + addAttributes(builder, beanDefinition); + return builder.build(); + } + + private void addInitDestroyMethods(Builder builder, + AbstractBeanDefinition beanDefinition, String[] methodNames, String format) { + + if (!ObjectUtils.isEmpty(methodNames)) { + Class beanUserClass = ClassUtils + .getUserClass(beanDefinition.getResolvableType().toClass()); + Builder arguments = CodeBlock.builder(); + for (int i = 0; i < methodNames.length; i++) { + String methodName = methodNames[i]; + if (!AbstractBeanDefinition.INFER_METHOD.equals(methodName)) { + arguments.add((i != 0) ? ", $S" : "$S", methodName); + addInitDestroyHint(beanUserClass, methodName); + } + } + builder.addStatement(format, BEAN_DEFINITION_VARIABLE, arguments.build()); + } + } + + private void addInitDestroyHint(Class beanUserClass, String methodName) { + Method method = ReflectionUtils.findMethod(beanUserClass, methodName); + if (method != null) { + this.hints.reflection().registerMethod(method); + } + } + + private void addConstructorArgumentValues(CodeBlock.Builder builder, + BeanDefinition beanDefinition) { + + Map argumentValues = beanDefinition + .getConstructorArgumentValues().getIndexedArgumentValues(); + if (!argumentValues.isEmpty()) { + argumentValues.forEach((index, valueHolder) -> { + String name = valueHolder.getName(); + Object value = valueHolder.getValue(); + CodeBlock code = this.customValueCodeGenerator.apply(name, value); + if (code == null) { + code = this.valueCodeGenerator.generateCode(value); + } + builder.addStatement( + "$L.getConstructorArgumentValues().addIndexedArgumentValue($L, $L)", + BEAN_DEFINITION_VARIABLE, index, code); + }); + } + } + + private void addPropertyValues(CodeBlock.Builder builder, + BeanDefinition beanDefinition) { + + MutablePropertyValues propertyValues = beanDefinition.getPropertyValues(); + if (!propertyValues.isEmpty()) { + for (PropertyValue propertyValue : propertyValues) { + String name = propertyValue.getName(); + Object value = propertyValue.getValue(); + CodeBlock code = this.customValueCodeGenerator.apply(name, value); + if (code == null) { + code = this.valueCodeGenerator.generateCode(value); + } + builder.addStatement("$L.getPropertyValues().addPropertyValue($S, $L)", + BEAN_DEFINITION_VARIABLE, propertyValue.getName(), code); + } + } + } + + private void addAttributes(CodeBlock.Builder builder, BeanDefinition beanDefinition) { + String[] attributeNames = beanDefinition.attributeNames(); + if (!ObjectUtils.isEmpty(attributeNames)) { + for (String attributeName : attributeNames) { + if (this.attributeFilter.test(attributeName)) { + CodeBlock value = this.valueCodeGenerator + .generateCode(beanDefinition.getAttribute(attributeName)); + builder.addStatement("$L.setAttribute($S, $L)", + BEAN_DEFINITION_VARIABLE, attributeName, value); + } + } + } + } + + private boolean hasScope(String defaultValue, String actualValue) { + return StringUtils.hasText(actualValue) + && !ConfigurableBeanFactory.SCOPE_SINGLETON.equals(actualValue); + } + + private boolean hasDependsOn(String[] defaultValue, String[] actualValue) { + return !ObjectUtils.isEmpty(actualValue); + } + + private boolean hasRole(int defaultValue, int actualValue) { + return actualValue != BeanDefinition.ROLE_APPLICATION; + } + + private CodeBlock toStringVarArgs(String[] strings) { + CodeBlock.Builder builder = CodeBlock.builder(); + for (int i = 0; i < strings.length; i++) { + builder.add((i != 0) ? ", " : ""); + builder.add("$S", strings[i]); + } + return builder.build(); + } + + private Object toRole(int value) { + return switch (value) { + case BeanDefinition.ROLE_INFRASTRUCTURE -> CodeBlock.builder() + .add("$T.ROLE_INFRASTRUCTURE", BeanDefinition.class).build(); + case BeanDefinition.ROLE_SUPPORT -> CodeBlock.builder() + .add("$T.ROLE_SUPPORT", BeanDefinition.class).build(); + default -> value; + }; + } + + private void addStatementForValue( + CodeBlock.Builder builder, BeanDefinition beanDefinition, + Function getter, String format) { + + addStatementForValue(builder, beanDefinition, getter, + (defaultValue, actualValue) -> !Objects.equals(defaultValue, actualValue), + format); + } + + private void addStatementForValue( + CodeBlock.Builder builder, BeanDefinition beanDefinition, + Function getter, BiPredicate filter, String format) { + + addStatementForValue(builder, beanDefinition, getter, filter, format, + actualValue -> actualValue); + } + + @SuppressWarnings("unchecked") + private void addStatementForValue( + CodeBlock.Builder builder, BeanDefinition beanDefinition, + Function getter, BiPredicate filter, String format, + Function formatter) { + + T defaultValue = getter.apply((B) DEFAULT_BEAN_DEFINITION); + T actualValue = getter.apply((B) beanDefinition); + if (filter.test(defaultValue, actualValue)) { + builder.addStatement(format, BEAN_DEFINITION_VARIABLE, + formatter.apply(actualValue)); + } + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java new file mode 100644 index 00000000000..b8ea68b23c4 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java @@ -0,0 +1,529 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; + +import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.MethodNameGenerator; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.BeanReference; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.support.ManagedList; +import org.springframework.beans.factory.support.ManagedMap; +import org.springframework.beans.factory.support.ManagedSet; +import org.springframework.core.ResolvableType; +import org.springframework.javapoet.AnnotationSpec; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; + +/** + * Internal code generator used to generate code for a single value contained in + * a {@link BeanDefinition} property. + * + * @author Stephane Nicoll + * @author Phillip Webb + * @since 6.0 + */ +class BeanDefinitionPropertyValueCodeGenerator { + + static final CodeBlock NULL_VALUE_CODE_BLOCK = CodeBlock.of("null"); + + private final MethodGenerator methodGenerator; + + private final List delegates; + + + BeanDefinitionPropertyValueCodeGenerator(MethodGenerator methodGenerator) { + this.methodGenerator = methodGenerator; + this.delegates = new ArrayList<>(); + this.delegates.add(new PrimitiveDelegate()); + this.delegates.add(new StringDelegate()); + this.delegates.add(new EnumDelegate()); + this.delegates.add(new ClassDelegate()); + this.delegates.add(new ResolvableTypeDelegate()); + this.delegates.add(new ArrayDelegate()); + this.delegates.add(new ManagedListDelegate()); + this.delegates.add(new ManagedSetDelegate()); + this.delegates.add(new ManagedMapDelegate()); + this.delegates.add(new ListDelegate()); + this.delegates.add(new SetDelegate()); + this.delegates.add(new MapDelegate()); + this.delegates.add(new BeanReferenceDelegate()); + } + + + CodeBlock generateCode(@Nullable Object value) { + ResolvableType type = (value != null) ? ResolvableType.forInstance(value) + : ResolvableType.NONE; + return generateCode(value, type); + } + + private CodeBlock generateCode(@Nullable Object value, ResolvableType type) { + if (value == null) { + return NULL_VALUE_CODE_BLOCK; + } + for (Delegate delegate : this.delegates) { + CodeBlock code = delegate.generateCode(value, type); + if (code != null) { + return code; + } + } + throw new IllegalArgumentException( + "'type' " + type + " must be supported for instance code generation"); + } + + + /** + * Internal delegate used to support generation for a specific type. + */ + @FunctionalInterface + private interface Delegate { + + @Nullable + CodeBlock generateCode(Object value, ResolvableType type); + + } + + + /** + * {@link Delegate} for {@code primitive} types. + */ + private class PrimitiveDelegate implements Delegate { + + private static final Map CHAR_ESCAPES; + + static { + Map escapes = new HashMap<>(); + escapes.put('\b', "\\b"); + escapes.put('\t', "\\t"); + escapes.put('\n', "\\n"); + escapes.put('\f', "\\f"); + escapes.put('\r', "\\r"); + escapes.put('\"', "\""); + escapes.put('\'', "\\'"); + escapes.put('\\', "\\\\"); + CHAR_ESCAPES = Collections.unmodifiableMap(escapes); + } + + @Override + @Nullable + public CodeBlock generateCode(Object value, ResolvableType type) { + if (value instanceof Boolean || value instanceof Integer) { + return CodeBlock.of("$L", value); + } + if (value instanceof Byte) { + return CodeBlock.of("(byte) $L", value); + } + if (value instanceof Short) { + return CodeBlock.of("(short) $L", value); + } + if (value instanceof Long) { + return CodeBlock.of("$LL", value); + } + if (value instanceof Float) { + return CodeBlock.of("$LF", value); + } + if (value instanceof Double) { + return CodeBlock.of("(double) $L", value); + } + if (value instanceof Character character) { + return CodeBlock.of("'$L'", escape(character)); + } + return null; + } + + private String escape(char ch) { + String escaped = CHAR_ESCAPES.get(ch); + if (escaped != null) { + return escaped; + } + return (!Character.isISOControl(ch)) ? Character.toString(ch) + : String.format("\\u%04x", (int) ch); + } + } + + + /** + * {@link Delegate} for {@link String} types. + */ + private class StringDelegate implements Delegate { + + @Override + @Nullable + public CodeBlock generateCode(Object value, ResolvableType type) { + if (value instanceof String) { + return CodeBlock.of("$S", value); + } + return null; + } + + } + + + /** + * {@link Delegate} for {@link Enum} types. + */ + private class EnumDelegate implements Delegate { + + @Override + @Nullable + public CodeBlock generateCode(Object value, ResolvableType type) { + if (value instanceof Enum enumValue) { + return CodeBlock.of("$T.$L", enumValue.getDeclaringClass(), + enumValue.name()); + } + return null; + } + + } + + + /** + * {@link Delegate} for {@link Class} types. + */ + private class ClassDelegate implements Delegate { + + @Override + @Nullable + public CodeBlock generateCode(Object value, ResolvableType type) { + if (value instanceof Class clazz) { + return CodeBlock.of("$T.class", ClassUtils.getUserClass(clazz)); + } + return null; + } + + } + + + /** + * {@link Delegate} for {@link ResolvableType} types. + */ + private class ResolvableTypeDelegate implements Delegate { + + @Override + @Nullable + public CodeBlock generateCode(Object value, ResolvableType type) { + if (value instanceof ResolvableType resolvableType) { + return ResolvableTypeCodeGenerator.generateCode(resolvableType); + } + return null; + } + + } + + + /** + * {@link Delegate} for {@code array} types. + */ + private class ArrayDelegate implements Delegate { + + @Override + @Nullable + public CodeBlock generateCode(@Nullable Object value, ResolvableType type) { + if (type.isArray()) { + ResolvableType componentType = type.getComponentType(); + int length = Array.getLength(value); + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("new $T {", type.toClass()); + for (int i = 0; i < length; i++) { + Object component = Array.get(value, i); + builder.add((i != 0) ? ", " : ""); + builder.add("$L", BeanDefinitionPropertyValueCodeGenerator.this + .generateCode(component, componentType)); + } + builder.add("}"); + return builder.build(); + } + return null; + } + + } + + + /** + * Abstract {@link Delegate} for {@code Collection} types. + */ + private abstract class CollectionDelegate> + implements Delegate { + + private final Class collectionType; + + private final CodeBlock emptyResult; + + public CollectionDelegate(Class collectionType, CodeBlock emptyResult) { + this.collectionType = collectionType; + this.emptyResult = emptyResult; + } + + @Override + @SuppressWarnings("unchecked") + @Nullable + public CodeBlock generateCode(Object value, ResolvableType type) { + if (this.collectionType.isInstance(value)) { + T collection = (T) value; + if (collection.isEmpty()) { + return this.emptyResult; + } + ResolvableType elementType = type.as(this.collectionType).getGeneric(); + return generateCollectionCode(elementType, collection); + } + return null; + } + + protected CodeBlock generateCollectionCode(ResolvableType elementType, + T collection) { + return generateCollectionOf(collection, this.collectionType, elementType); + } + + protected final CodeBlock generateCollectionOf(Collection collection, + Class collectionType, ResolvableType elementType) { + Builder builder = CodeBlock.builder(); + builder.add("$T.of(", collectionType); + Iterator iterator = collection.iterator(); + while (iterator.hasNext()) { + Object element = iterator.next(); + builder.add("$L", BeanDefinitionPropertyValueCodeGenerator.this + .generateCode(element, elementType)); + builder.add((!iterator.hasNext()) ? "" : ", "); + } + builder.add(")"); + return builder.build(); + } + + } + + + /** + * {@link Delegate} for {@link ManagedList} types. + */ + private class ManagedListDelegate extends CollectionDelegate> { + + public ManagedListDelegate() { + super(ManagedList.class, CodeBlock.of("new $T()", ManagedList.class)); + } + + } + + + /** + * {@link Delegate} for {@link ManagedSet} types. + */ + private class ManagedSetDelegate extends CollectionDelegate> { + + public ManagedSetDelegate() { + super(ManagedSet.class, CodeBlock.of("new $T()", ManagedSet.class)); + } + + } + + + /** + * {@link Delegate} for {@link ManagedMap} types. + */ + private class ManagedMapDelegate implements Delegate { + + private static final CodeBlock EMPTY_RESULT = CodeBlock.of("$T.ofEntries()", + ManagedMap.class); + + @Override + @Nullable + public CodeBlock generateCode(Object value, ResolvableType type) { + if (value instanceof ManagedMap managedMap) { + return generateManagedMapCode(type, managedMap); + } + return null; + } + + private CodeBlock generateManagedMapCode(ResolvableType type, + ManagedMap managedMap) { + if (managedMap.isEmpty()) { + return EMPTY_RESULT; + } + ResolvableType keyType = type.as(Map.class).getGeneric(0); + ResolvableType valueType = type.as(Map.class).getGeneric(1); + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("$T.ofEntries(", ManagedMap.class); + Iterator> iterator = managedMap.entrySet().iterator(); + while (iterator.hasNext()) { + Entry entry = iterator.next(); + builder.add("$T.entry($L,$L)", Map.class, + BeanDefinitionPropertyValueCodeGenerator.this + .generateCode(entry.getKey(), keyType), + BeanDefinitionPropertyValueCodeGenerator.this + .generateCode(entry.getValue(), valueType)); + builder.add((!iterator.hasNext()) ? "" : ", "); + } + builder.add(")"); + return builder.build(); + } + + } + + + /** + * {@link Delegate} for {@link List} types. + */ + private class ListDelegate extends CollectionDelegate> { + + ListDelegate() { + super(List.class, CodeBlock.of("$T.emptyList()", Collections.class)); + } + + } + + + /** + * {@link Delegate} for {@link Set} types. + */ + private class SetDelegate extends CollectionDelegate> { + + SetDelegate() { + super(Set.class, CodeBlock.of("$T.emptySet()", Collections.class)); + } + + @Override + protected CodeBlock generateCollectionCode(ResolvableType elementType, + Set set) { + if (set instanceof LinkedHashSet) { + return CodeBlock.of("new $T($L)", LinkedHashSet.class, + generateCollectionOf(set, List.class, elementType)); + } + set = orderForCodeConsistency(set); + return super.generateCollectionCode(elementType, set); + } + + private Set orderForCodeConsistency(Set set) { + return new TreeSet(set); + } + + } + + + /** + * {@link Delegate} for {@link Map} types. + */ + private class MapDelegate implements Delegate { + + private static final CodeBlock EMPTY_RESULT = CodeBlock.of("$T.emptyMap()", + Collections.class); + + @Override + @Nullable + public CodeBlock generateCode(Object value, ResolvableType type) { + if (value instanceof Map map) { + return generateMapCode(type, map); + } + return null; + } + + private CodeBlock generateMapCode(ResolvableType type, Map map) { + if (map.isEmpty()) { + return EMPTY_RESULT; + } + ResolvableType keyType = type.as(Map.class).getGeneric(0); + ResolvableType valueType = type.as(Map.class).getGeneric(1); + if (map instanceof LinkedHashMap) { + return generateLinkedHashMapCode(map, keyType, valueType); + } + map = orderForCodeConsistency(map); + boolean useOfEntries = map.size() > 10; + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("$T" + ((!useOfEntries) ? ".of(" : ".ofEntries("), Map.class); + Iterator> iterator = map.entrySet().iterator(); + while (iterator.hasNext()) { + Entry entry = iterator.next(); + CodeBlock keyCode = BeanDefinitionPropertyValueCodeGenerator.this + .generateCode(entry.getKey(), keyType); + CodeBlock valueCode = BeanDefinitionPropertyValueCodeGenerator.this + .generateCode(entry.getValue(), valueType); + if (!useOfEntries) { + builder.add("$L, $L", keyCode, valueCode); + } + else { + builder.add("$T.entry($L,$L)", Map.class, keyCode, valueCode); + } + builder.add((!iterator.hasNext()) ? "" : ", "); + } + builder.add(")"); + return builder.build(); + } + + private Map orderForCodeConsistency(Map map) { + return new TreeMap<>(map); + } + + private CodeBlock generateLinkedHashMapCode(Map map, + ResolvableType keyType, ResolvableType valueType) { + GeneratedMethod method = BeanDefinitionPropertyValueCodeGenerator.this.methodGenerator + .generateMethod(MethodNameGenerator.join("get", "map")) + .using(builder -> { + builder.addAnnotation(AnnotationSpec + .builder(SuppressWarnings.class) + .addMember("value", "{\"rawtypes\", \"unchecked\"}") + .build()); + builder.returns(Map.class); + builder.addStatement("$T map = new $T($L)", Map.class, + LinkedHashMap.class, map.size()); + map.forEach( + (key, value) -> builder.addStatement("map.put($L, $L)", + BeanDefinitionPropertyValueCodeGenerator.this + .generateCode(key, keyType), + BeanDefinitionPropertyValueCodeGenerator.this + .generateCode(value, valueType))); + builder.addStatement("return map"); + }); + return CodeBlock.of("$L()", method.getName()); + } + + } + + + /** + * {@link Delegate} for {@link BeanReference} types. + */ + private class BeanReferenceDelegate implements Delegate { + + @Override + @Nullable + public CodeBlock generateCode(Object value, ResolvableType type) { + if (value instanceof BeanReference beanReference) { + return CodeBlock.of("new $T($S)", RuntimeBeanReference.class, + beanReference.getBeanName()); + } + return null; + } + + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationAotContribution.java new file mode 100644 index 00000000000..7d3b497e9de --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationAotContribution.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.springframework.aot.generate.GenerationContext; + +/** + * AOT contribution from a {@link BeanFactoryInitializationAotProcessor} used to + * initialize a bean factory. + * + * @author Phillip Webb + * @since 6.0 + * @see BeanFactoryInitializationAotProcessor + */ +@FunctionalInterface +public interface BeanFactoryInitializationAotContribution { + + /** + * Apply this contribution to the given + * {@link BeanFactoryInitializationCode}. + * @param generationContext the active generation context + * @param beanFactoryInitializationCode the bean factory initialization code + */ + void applyTo(GenerationContext generationContext, + BeanFactoryInitializationCode beanFactoryInitializationCode); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationAotProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationAotProcessor.java new file mode 100644 index 00000000000..9765e2c1f58 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationAotProcessor.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.lang.Nullable; + +/** + * AOT processor that makes bean factory initialization contributions by + * processing {@link ConfigurableListableBeanFactory} instances. + * + * @author Phillip Webb + * @since 6.0 + * @see BeanFactoryInitializationAotContribution + */ +@FunctionalInterface +public interface BeanFactoryInitializationAotProcessor { + + /** + * Process the given {@link ConfigurableListableBeanFactory} instance + * ahead-of-time and return a contribution or {@code null}. + *

+ * Processors are free to use any techniques they like to analyze the given + * instance. Most typically use reflection to find fields or methods to use + * in the contribution. Contributions typically generate source code or + * resource files that can be used when the AOT optimized application runs. + *

+ * If the given instance isn't relevant to the processor, it should return a + * {@code null} contribution. + * @param beanFactory the bean factory to process + * @return a {@link BeanFactoryInitializationAotContribution} or + * {@code null} + */ + @Nullable + BeanFactoryInitializationAotContribution processAheadOfTime( + ConfigurableListableBeanFactory beanFactory); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java new file mode 100644 index 00000000000..92e250ba7e9 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.MethodReference; + +/** + * Interface that can be used to configure the code that will be generated to + * perform bean factory initialization. + * + * @author Phillip Webb + * @since 6.0 + * @see BeanFactoryInitializationAotContribution + */ +public interface BeanFactoryInitializationCode { + + /** + * The recommended variable name to used referring to the bean factory. + */ + String BEAN_FACTORY_VARIABLE = "beanFactory"; + + /** + * Return a {@link MethodGenerator} that can be used to add more methods to + * the Initializing code. + * @return the method generator + */ + MethodGenerator getMethodGenerator(); + + /** + * Add an initializer method call. + * @param methodReference a reference to the initialize method to call. The + * referenced method must have the same functional signature as + * {@code Consumer}. + */ + void addInitializer(MethodReference methodReference); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationAotContribution.java new file mode 100644 index 00000000000..be02a4bbe5b --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationAotContribution.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.springframework.aot.generate.GenerationContext; + +/** + * AOT contribution from a {@link BeanRegistrationAotProcessor} used to register + * a single bean definition. + * + * @author Phillip Webb + * @since 6.0 + * @see BeanRegistrationAotProcessor + */ +@FunctionalInterface +public interface BeanRegistrationAotContribution { + + /** + * Apply this contribution to the given {@link BeanRegistrationCode}. + * @param generationContext the active generation context + * @param beanRegistrationCode the generated registration + */ + void applyTo(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationAotProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationAotProcessor.java new file mode 100644 index 00000000000..646b632958d --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationAotProcessor.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.lang.Nullable; + +/** + * AOT processor that makes bean registration contributions by processing + * {@link RegisteredBean} instances. + * + * @author Phillip Webb + * @since 6.0 + * @see BeanRegistrationAotContribution + */ +@FunctionalInterface +public interface BeanRegistrationAotProcessor { + + /** + * Process the given {@link RegisteredBean} instance ahead-of-time and + * return a contribution or {@code null}. + *

+ * Processors are free to use any techniques they like to analyze the given + * instance. Most typically use reflection to find fields or methods to use + * in the contribution. Contributions typically generate source code or + * resource files that can be used when the AOT optimized application runs. + *

+ * If the given instance isn't relevant to the processor, it should return a + * {@code null} contribution. + * @param registeredBean the registered bean to process + * @return a {@link BeanRegistrationAotContribution} or {@code null} + */ + @Nullable + BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCode.java new file mode 100644 index 00000000000..fe8909b1eb0 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCode.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.MethodReference; +import org.springframework.beans.factory.support.InstanceSupplier; +import org.springframework.javapoet.ClassName; + +/** + * Interface that can be used to configure the code that will be generated to + * perform registration of a single bean. + * + * @author Phillip Webb + * @since 6.0 + * @see BeanRegistrationCodeFragments + * @see BeanRegistrationCodeFragmentsCustomizer + */ +public interface BeanRegistrationCode { + + /** + * Return the name of the class being used for registrations. + * @return the name of the class + */ + ClassName getClassName(); + + /** + * Return a {@link MethodGenerator} that can be used to add more methods to + * the registrations code. + * @return the method generator + */ + MethodGenerator getMethodGenerator(); + + /** + * Add an instance post processor method call to the registration code. + * @param methodReference a reference to the post-process method to call. + * The referenced method must have a functional signature compatible with + * {@link InstanceSupplier#andThen}. + * @see InstanceSupplier#andThen(org.springframework.util.function.ThrowableBiFunction) + */ + void addInstancePostProcessor(MethodReference methodReference); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java new file mode 100644 index 00000000000..acc2c0b935f --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java @@ -0,0 +1,166 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Executable; +import java.util.List; +import java.util.function.Predicate; + +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodReference; +import org.springframework.beans.factory.support.InstanceSupplier; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.core.ResolvableType; +import org.springframework.javapoet.CodeBlock; +import org.springframework.util.Assert; + +/** + * Class used to generate the various fragments of code needed to register a + * bean. + * + * @author Phillip Webb + * @since 6.0 + * @see BeanRegistrationCodeFragmentsWrapper + * @see BeanRegistrationCodeFragmentsCustomizer + */ +public abstract class BeanRegistrationCodeFragments { + + /** + * The variable name to used when creating the bean definition. + */ + protected static final String BEAN_DEFINITION_VARIABLE = "beanDefinition"; + + /** + * The variable name to used when creating the bean definition. + */ + protected static final String INSTANCE_SUPPLIER_VARIABLE = "instanceSupplier"; + + + private final BeanRegistrationCodeFragments codeFragments; + + + protected BeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments) { + Assert.notNull(codeFragments, "'codeFragments' must not be null"); + this.codeFragments = codeFragments; + } + + + /** + * Package-private constructor exclusively for + * {@link DefaultBeanRegistrationCodeFragments}. + */ + BeanRegistrationCodeFragments() { + this.codeFragments = null; + } + + /** + * Return the target for the registration. Used to determine where to write + * the code. + * @param registeredBean the registered bean + * @param constructorOrFactoryMethod the constructor or factory method + * @return the target class + */ + public Class getTarget(RegisteredBean registeredBean, + Executable constructorOrFactoryMethod) { + + return this.codeFragments.getTarget(registeredBean, constructorOrFactoryMethod); + } + + /** + * Generate the code that defines the new bean definition instance. + * @param generationContext the generation context + * @param beanType the bean type + * @param beanRegistrationCode the bean registration code + * @return the generated code + */ + public CodeBlock generateNewBeanDefinitionCode(GenerationContext generationContext, + ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) { + + return this.codeFragments.generateNewBeanDefinitionCode(generationContext, + beanType, beanRegistrationCode); + + } + + /** + * Generate the code that sets the properties of the bean definition. + * @param generationContext the generation context + * @param beanRegistrationCode the bean registration code + * @param attributeFilter any attribute filtering that should be applied + * @return the generated code + */ + public CodeBlock generateSetBeanDefinitionPropertiesCode( + GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, RootBeanDefinition beanDefinition, + Predicate attributeFilter) { + + return this.codeFragments.generateSetBeanDefinitionPropertiesCode( + generationContext, beanRegistrationCode, beanDefinition, attributeFilter); + + } + + /** + * Generate the code that sets the instance supplier on the bean definition. + * @param generationContext the generation context + * @param beanRegistrationCode the bean registration code + * @param instanceSupplierCode the instance supplier code supplier code + * @param postProcessors any instance post processors that should be applied + * @return the generated code + * @see #generateInstanceSupplierCode + */ + public CodeBlock generateSetBeanInstanceSupplierCode( + GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, CodeBlock instanceSupplierCode, + List postProcessors) { + + return this.codeFragments.generateSetBeanInstanceSupplierCode(generationContext, + beanRegistrationCode, instanceSupplierCode, postProcessors); + } + + /** + * Generate the instance supplier code. + * @param generationContext the generation context + * @param beanRegistrationCode the bean registration code + * @param constructorOrFactoryMethod the constructor or factory method for + * the bean + * @param allowDirectSupplierShortcut if direct suppliers may be used rather + * than always needing an {@link InstanceSupplier} + * @return the generated code + */ + public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, + Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut) { + + return this.codeFragments.generateInstanceSupplierCode(generationContext, + beanRegistrationCode, constructorOrFactoryMethod, + allowDirectSupplierShortcut); + } + + /** + * Generate the return statement. + * @param generationContext the generation context + * @param beanRegistrationCode the bean registration code + * @return the generated code + */ + public CodeBlock generateReturnCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode) { + + return this.codeFragments.generateReturnCode(generationContext, + beanRegistrationCode); + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsCustomizer.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsCustomizer.java new file mode 100644 index 00000000000..d94ccd215d4 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsCustomizer.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.springframework.beans.factory.support.RegisteredBean; + +/** + * Strategy factory interface that can be used to customize the + * {@link BeanRegistrationCodeFragments} that us used for a given + * {@link RegisteredBean}. This interface can be used if default code generation + * isn't suitable for specific types of {@link RegisteredBean}. + * + * @author Phillip Webb + * @since 6.0 + */ +@FunctionalInterface +public interface BeanRegistrationCodeFragmentsCustomizer { + + /** + * Apply this {@link BeanRegistrationCodeFragmentsCustomizer} to the given + * {@link BeanRegistrationCodeFragments code fragments generator}. The + * returned code generator my be a + * {@link BeanRegistrationCodeFragmentsWrapper wrapper} around the original. + * @param registeredBean the registered bean + * @param codeFragments the existing code fragments + * @return the code generator to use, either the original or a wrapped one; + */ + BeanRegistrationCodeFragments customizeBeanRegistrationCodeFragments( + RegisteredBean registeredBean, BeanRegistrationCodeFragments codeFragments); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java new file mode 100644 index 00000000000..2371751fbf8 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Executable; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Predicate; + +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.MethodReference; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.util.Assert; + +/** + * {@link BeanRegistrationCode} implementation with code generation support. + * + * @author Phillip Webb + * @since 6.0 + */ +class BeanRegistrationCodeGenerator implements BeanRegistrationCode { + + private static final Predicate NO_ATTRIBUTE_FILTER = attribute -> true; + + private final ClassName className; + + private final MethodGenerator methodGenerator; + + private final List instancePostProcessors = new ArrayList<>(); + + private final RegisteredBean registeredBean; + + private final Executable constructorOrFactoryMethod; + + private final BeanRegistrationCodeFragments codeFragments; + + + BeanRegistrationCodeGenerator(ClassName className, MethodGenerator methodGenerator, + RegisteredBean registeredBean, Executable constructorOrFactoryMethod, + BeanRegistrationCodeFragments codeFragments) { + + this.className = className; + this.methodGenerator = methodGenerator; + this.registeredBean = registeredBean; + this.constructorOrFactoryMethod = constructorOrFactoryMethod; + this.codeFragments = codeFragments; + } + + @Override + public ClassName getClassName() { + return this.className; + } + + @Override + public MethodGenerator getMethodGenerator() { + return this.methodGenerator; + } + + @Override + public void addInstancePostProcessor(MethodReference methodReference) { + Assert.notNull(methodReference, "MethodReference must not be null"); + this.instancePostProcessors.add(methodReference); + } + + CodeBlock generateCode(GenerationContext generationContext) { + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add(this.codeFragments.generateNewBeanDefinitionCode(generationContext, + this.registeredBean.getBeanType(), this)); + builder.add(this.codeFragments.generateSetBeanDefinitionPropertiesCode( + generationContext, this, this.registeredBean.getMergedBeanDefinition(), + NO_ATTRIBUTE_FILTER)); + CodeBlock instanceSupplierCode = this.codeFragments.generateInstanceSupplierCode( + generationContext, this, this.constructorOrFactoryMethod, + this.instancePostProcessors.isEmpty()); + builder.add( + this.codeFragments.generateSetBeanInstanceSupplierCode(generationContext, + this, instanceSupplierCode, this.instancePostProcessors)); + builder.add(this.codeFragments.generateReturnCode(generationContext, this)); + return builder.build(); + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationExcludeFilter.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationExcludeFilter.java new file mode 100644 index 00000000000..a4da8642e8a --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationExcludeFilter.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.springframework.beans.factory.support.RegisteredBean; + +/** + * Filter that can be used to exclude AOT processing of a + * {@link RegisteredBean}. + * + * @author Phillip Webb + * @author Stephane Nicoll + * @since 6.0 + */ +@FunctionalInterface +public interface BeanRegistrationExcludeFilter { + + /** + * Return if the registered bean should be excluded from AOT processing and + * registration. + * @param registeredBean the registered bean + * @return if the registered bean should be excluded + */ + boolean isExcluded(RegisteredBean registeredBean); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java new file mode 100644 index 00000000000..8057fb1577b --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.util.Map; + +import javax.lang.model.element.Modifier; + +import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.MethodReference; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.TypeSpec; + +/** + * AOT contribution from a {@link BeanRegistrationsAotProcessor} used to + * register bean definitions. + * + * @author Phillip Webb + * @since 6.0 + * @see BeanRegistrationsAotProcessor + */ +class BeanRegistrationsAotContribution + implements BeanFactoryInitializationAotContribution { + + + private static final String BEAN_FACTORY_PARAMETER_NAME = "beanFactory"; + + + private final Map registrations; + + + BeanRegistrationsAotContribution( + Map registrations) { + + this.registrations = registrations; + } + + + @Override + public void applyTo(GenerationContext generationContext, + BeanFactoryInitializationCode beanFactoryInitializationCode) { + + ClassName className = generationContext.getClassNameGenerator() + .generateClassName("BeanFactory", "Registrations"); + BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator( + className); + GeneratedMethod registerMethod = codeGenerator.getMethodGenerator() + .generateMethod("registerBeanDefinitions") + .using(builder -> generateRegisterMethod(builder, generationContext, + codeGenerator)); + JavaFile javaFile = codeGenerator.generatedJavaFile(className); + generationContext.getGeneratedFiles().addSourceFile(javaFile); + beanFactoryInitializationCode + .addInitializer(MethodReference.of(className, registerMethod.getName())); + } + + private void generateRegisterMethod(MethodSpec.Builder builder, + GenerationContext generationContext, + BeanRegistrationsCode beanRegistrationsCode) { + + builder.addJavadoc("Register the bean definitions."); + builder.addModifiers(Modifier.PUBLIC); + builder.addParameter(DefaultListableBeanFactory.class, + BEAN_FACTORY_PARAMETER_NAME); + CodeBlock.Builder code = CodeBlock.builder(); + this.registrations.forEach((beanName, beanDefinitionMethodGenerator) -> { + MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator + .generateBeanDefinitionMethod(generationContext, + beanRegistrationsCode); + code.addStatement("$L.registerBeanDefinition($S, $L)", + BEAN_FACTORY_PARAMETER_NAME, beanName, + beanDefinitionMethod.toInvokeCodeBlock()); + }); + builder.addCode(code.build()); + } + + + /** + * {@link BeanRegistrationsCode} with generation support. + */ + static class BeanRegistrationsCodeGenerator implements BeanRegistrationsCode { + + private final ClassName className; + + private final GeneratedMethods generatedMethods = new GeneratedMethods(); + + + public BeanRegistrationsCodeGenerator(ClassName className) { + this.className = className; + } + + + @Override + public ClassName getClassName() { + return this.className; + } + + @Override + public MethodGenerator getMethodGenerator() { + return this.generatedMethods; + } + + JavaFile generatedJavaFile(ClassName className) { + TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); + classBuilder.addJavadoc("Register bean definitions for the bean factory."); + classBuilder.addModifiers(Modifier.PUBLIC); + this.generatedMethods.doWithMethodSpecs(classBuilder::addMethod); + return JavaFile.builder(className.packageName(), classBuilder.build()) + .build(); + } + + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotProcessor.java new file mode 100644 index 00000000000..02f6c48f168 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotProcessor.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.RegisteredBean; + +/** + * {@link BeanFactoryInitializationAotProcessor} that contributes code to + * register beans. + * + * @author Phillip Webb + * @since 6.0 + */ +class BeanRegistrationsAotProcessor implements BeanFactoryInitializationAotProcessor { + + @Override + public BeanRegistrationsAotContribution processAheadOfTime( + ConfigurableListableBeanFactory beanFactory) { + + BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory = + new BeanDefinitionMethodGeneratorFactory(beanFactory); + Map registrations = new LinkedHashMap<>(); + for (String beanName : beanFactory.getBeanDefinitionNames()) { + RegisteredBean registeredBean = RegisteredBean.of(beanFactory, beanName); + BeanDefinitionMethodGenerator beanDefinitionMethodGenerator = beanDefinitionMethodGeneratorFactory + .getBeanDefinitionMethodGenerator(registeredBean, null); + if (beanDefinitionMethodGenerator != null) { + registrations.put(beanName, beanDefinitionMethodGenerator); + } + } + if (registrations.isEmpty()) { + return null; + } + return new BeanRegistrationsAotContribution(registrations); + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsCode.java new file mode 100644 index 00000000000..b3ff475fe81 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsCode.java @@ -0,0 +1,44 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.javapoet.ClassName; + +/** + * Interface that can be used to configure the code that will be generated to + * register beans. + * + * @author Phillip Webb + * @since 6.0 + */ +public interface BeanRegistrationsCode { + + /** + * Return the name of the class being used for registrations. + * @return the generated class name. + */ + ClassName getClassName(); + + /** + * Return a {@link MethodGenerator} that can be used to add more methods to + * the registrations code. + * @return the method generator + */ + MethodGenerator getMethodGenerator(); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/ConstructorOrFactoryMethodResolver.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/ConstructorOrFactoryMethodResolver.java new file mode 100644 index 00000000000..108978aae6e --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/ConstructorOrFactoryMethodResolver.java @@ -0,0 +1,450 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.BeanReference; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues; +import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; +import org.springframework.beans.factory.support.AbstractBeanDefinition; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.MergedAnnotations; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +/** + * Resolves the {@link Executable} (factory method or constructor) that should + * be used to create a bean. This class is similar to + * {@code org.springframework.beans.factory.support.ConstructorResolver} but it + * doesn't need bean initialization. + * + * @author Stephane Nicoll + * @author Phillip Webb + * @since 6.0 + */ +class ConstructorOrFactoryMethodResolver { + + private static final Log logger = LogFactory + .getLog(ConstructorOrFactoryMethodResolver.class); + + + private final ConfigurableBeanFactory beanFactory; + + private final ClassLoader classLoader; + + + ConstructorOrFactoryMethodResolver(ConfigurableBeanFactory beanFactory) { + this.beanFactory = beanFactory; + this.classLoader = (beanFactory.getBeanClassLoader() != null) + ? beanFactory.getBeanClassLoader() : ClassUtils.getDefaultClassLoader(); + } + + + Executable resolve(BeanDefinition beanDefinition) { + Supplier beanType = () -> getBeanType(beanDefinition); + List valueTypes = beanDefinition.hasConstructorArgumentValues() + ? determineParameterValueTypes( + beanDefinition.getConstructorArgumentValues()) + : Collections.emptyList(); + Method resolvedFactoryMethod = resolveFactoryMethod(beanDefinition, valueTypes); + if (resolvedFactoryMethod != null) { + return resolvedFactoryMethod; + } + Class factoryBeanClass = getFactoryBeanClass(beanDefinition); + if (factoryBeanClass != null && !factoryBeanClass + .equals(beanDefinition.getResolvableType().toClass())) { + ResolvableType resolvableType = beanDefinition.getResolvableType(); + boolean isCompatible = ResolvableType.forClass(factoryBeanClass) + .as(FactoryBean.class).getGeneric(0).isAssignableFrom(resolvableType); + Assert.state(isCompatible, + () -> String.format( + "Incompatible target type '%s' for factory bean '%s'", + resolvableType.toClass().getName(), + factoryBeanClass.getName())); + return resolveConstructor(() -> ResolvableType.forClass(factoryBeanClass), + valueTypes); + } + Executable resolvedConstructor = resolveConstructor(beanType, valueTypes); + if (resolvedConstructor != null) { + return resolvedConstructor; + } + Executable resolvedConstructorOrFactoryMethod = getField(beanDefinition, + "resolvedConstructorOrFactoryMethod", Executable.class); + if (resolvedConstructorOrFactoryMethod != null) { + logger.error( + "resolvedConstructorOrFactoryMethod required for " + beanDefinition); + return resolvedConstructorOrFactoryMethod; + } + return null; + } + + private List determineParameterValueTypes( + ConstructorArgumentValues constructorArgumentValues) { + + List parameterTypes = new ArrayList<>(); + for (ValueHolder valueHolder : constructorArgumentValues + .getIndexedArgumentValues().values()) { + parameterTypes.add(determineParameterValueType(valueHolder)); + } + return parameterTypes; + } + + private ResolvableType determineParameterValueType(ValueHolder valueHolder) { + if (valueHolder.getType() != null) { + return ResolvableType.forClass(loadClass(valueHolder.getType())); + } + Object value = valueHolder.getValue(); + if (value instanceof BeanReference) { + return ResolvableType.forClass(this.beanFactory + .getType(((BeanReference) value).getBeanName(), false)); + } + if (value instanceof BeanDefinition) { + return extractTypeFromBeanDefinition(getBeanType((BeanDefinition) value)); + } + if (value instanceof Class) { + return ResolvableType.forClassWithGenerics(Class.class, (Class) value); + } + return ResolvableType.forInstance(value); + } + + private ResolvableType extractTypeFromBeanDefinition(ResolvableType type) { + if (FactoryBean.class.isAssignableFrom(type.toClass())) { + return type.as(FactoryBean.class).getGeneric(0); + } + return type; + } + + @Nullable + private Method resolveFactoryMethod(BeanDefinition beanDefinition, + List valueTypes) { + + if (beanDefinition instanceof RootBeanDefinition rbd) { + Method resolvedFactoryMethod = rbd.getResolvedFactoryMethod(); + if (resolvedFactoryMethod != null) { + return resolvedFactoryMethod; + } + } + String factoryMethodName = beanDefinition.getFactoryMethodName(); + if (factoryMethodName != null) { + String factoryBeanName = beanDefinition.getFactoryBeanName(); + Class beanClass = getBeanClass((factoryBeanName != null) + ? this.beanFactory.getMergedBeanDefinition(factoryBeanName) + : beanDefinition); + List methods = new ArrayList<>(); + Assert.state(beanClass != null, + () -> "Failed to determine bean class of " + beanDefinition); + ReflectionUtils.doWithMethods(beanClass, methods::add, + method -> isFactoryMethodCandidate(beanClass, method, + factoryMethodName)); + if (methods.size() >= 1) { + Function> parameterTypesFactory = method -> { + List types = new ArrayList<>(); + for (int i = 0; i < method.getParameterCount(); i++) { + types.add(ResolvableType.forMethodParameter(method, i)); + } + return types; + }; + return (Method) resolveFactoryMethod(methods, parameterTypesFactory, + valueTypes); + } + } + return null; + } + + private boolean isFactoryMethodCandidate(Class beanClass, Method method, + String factoryMethodName) { + + if (method.getName().equals(factoryMethodName)) { + if (Modifier.isStatic(method.getModifiers())) { + return method.getDeclaringClass().equals(beanClass); + } + return !Modifier.isPrivate(method.getModifiers()); + } + return false; + } + + @Nullable + private Executable resolveConstructor(Supplier beanType, + List valueTypes) { + + Class type = ClassUtils.getUserClass(beanType.get().toClass()); + Constructor[] constructors = type.getDeclaredConstructors(); + if (constructors.length == 1) { + return constructors[0]; + } + for (Constructor constructor : constructors) { + if (MergedAnnotations.from(constructor).isPresent(Autowired.class)) { + return constructor; + } + } + Function, List> parameterTypesFactory = executable -> { + List types = new ArrayList<>(); + for (int i = 0; i < executable.getParameterCount(); i++) { + types.add(ResolvableType.forConstructorParameter(executable, i)); + } + return types; + }; + List matches = Arrays.stream(constructors) + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, FallbackMode.NONE)) + .toList(); + if (matches.size() == 1) { + return matches.get(0); + } + List assignableElementFallbackMatches = Arrays + .stream(constructors) + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, FallbackMode.ASSIGNABLE_ELEMENT)) + .toList(); + if (assignableElementFallbackMatches.size() == 1) { + return assignableElementFallbackMatches.get(0); + } + List typeConversionFallbackMatches = Arrays + .stream(constructors) + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, + ConstructorOrFactoryMethodResolver.FallbackMode.TYPE_CONVERSION)) + .toList(); + return (typeConversionFallbackMatches.size() == 1) + ? typeConversionFallbackMatches.get(0) : null; + } + + private Executable resolveFactoryMethod(List executables, + Function> parameterTypesFactory, + List valueTypes) { + + List matches = executables.stream() + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, ConstructorOrFactoryMethodResolver.FallbackMode.NONE)) + .toList(); + if (matches.size() == 1) { + return matches.get(0); + } + List assignableElementFallbackMatches = executables.stream() + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, + ConstructorOrFactoryMethodResolver.FallbackMode.ASSIGNABLE_ELEMENT)) + .toList(); + if (assignableElementFallbackMatches.size() == 1) { + return assignableElementFallbackMatches.get(0); + } + List typeConversionFallbackMatches = executables.stream() + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, + ConstructorOrFactoryMethodResolver.FallbackMode.TYPE_CONVERSION)) + .toList(); + Assert.state(typeConversionFallbackMatches.size() <= 1, + () -> "Multiple matches with parameters '" + valueTypes + "': " + + typeConversionFallbackMatches); + return (typeConversionFallbackMatches.size() == 1) + ? typeConversionFallbackMatches.get(0) : null; + } + + private boolean match(List parameterTypes, + List valueTypes, + ConstructorOrFactoryMethodResolver.FallbackMode fallbackMode) { + + if (parameterTypes.size() != valueTypes.size()) { + return false; + } + for (int i = 0; i < parameterTypes.size(); i++) { + if (!isMatch(parameterTypes.get(i), valueTypes.get(i), fallbackMode)) { + return false; + } + } + return true; + } + + private boolean isMatch(ResolvableType parameterType, ResolvableType valueType, + ConstructorOrFactoryMethodResolver.FallbackMode fallbackMode) { + + if (isAssignable(valueType).test(parameterType)) { + return true; + } + return switch (fallbackMode) { + case ASSIGNABLE_ELEMENT -> isAssignable(valueType) + .test(extractElementType(parameterType)); + case TYPE_CONVERSION -> typeConversionFallback(valueType).test(parameterType); + default -> false; + }; + } + + private Predicate isAssignable(ResolvableType valueType) { + return parameterType -> parameterType.isAssignableFrom(valueType); + } + + private ResolvableType extractElementType(ResolvableType parameterType) { + if (parameterType.isArray()) { + return parameterType.getComponentType(); + } + if (Collection.class.isAssignableFrom(parameterType.toClass())) { + return parameterType.as(Collection.class).getGeneric(0); + } + return ResolvableType.NONE; + } + + private Predicate typeConversionFallback(ResolvableType valueType) { + return parameterType -> { + if (valueOrCollection(valueType, this::isStringForClassFallback) + .test(parameterType)) { + return true; + } + return valueOrCollection(valueType, this::isSimpleConvertibleType) + .test(parameterType); + }; + } + + private Predicate valueOrCollection(ResolvableType valueType, + Function> predicateProvider) { + + return parameterType -> { + if (predicateProvider.apply(valueType).test(parameterType)) { + return true; + } + if (predicateProvider.apply(extractElementType(valueType)) + .test(extractElementType(parameterType))) { + return true; + } + return (predicateProvider.apply(valueType) + .test(extractElementType(parameterType))); + }; + } + + /** + * Return a {@link Predicate} for a parameter type that checks if its target + * value is a {@link Class} and the value type is a {@link String}. This is + * a regular use cases where a {@link Class} is defined in the bean + * definition as an FQN. + * @param valueType the type of the value + * @return a predicate to indicate a fallback match for a String to Class + * parameter + */ + private Predicate isStringForClassFallback(ResolvableType valueType) { + return parameterType -> (valueType.isAssignableFrom(String.class) + && parameterType.isAssignableFrom(Class.class)); + } + + private Predicate isSimpleConvertibleType(ResolvableType valueType) { + return parameterType -> isSimpleConvertibleType(parameterType.toClass()) + && isSimpleConvertibleType(valueType.toClass()); + } + + @Nullable + private Class getFactoryBeanClass(BeanDefinition beanDefinition) { + if (beanDefinition instanceof RootBeanDefinition rbd) { + if (rbd.hasBeanClass()) { + Class beanClass = rbd.getBeanClass(); + return FactoryBean.class.isAssignableFrom(beanClass) ? beanClass : null; + } + } + return null; + } + + @Nullable + private Class getBeanClass(BeanDefinition beanDefinition) { + if (beanDefinition instanceof AbstractBeanDefinition abd) { + return abd.hasBeanClass() ? abd.getBeanClass() + : loadClass(abd.getBeanClassName()); + } + return (beanDefinition.getBeanClassName() != null) + ? loadClass(beanDefinition.getBeanClassName()) : null; + } + + private ResolvableType getBeanType(BeanDefinition beanDefinition) { + ResolvableType resolvableType = beanDefinition.getResolvableType(); + if (resolvableType != ResolvableType.NONE) { + return resolvableType; + } + if (beanDefinition instanceof RootBeanDefinition rbd) { + if (rbd.hasBeanClass()) { + return ResolvableType.forClass(rbd.getBeanClass()); + } + } + String beanClassName = beanDefinition.getBeanClassName(); + if (beanClassName != null) { + return ResolvableType.forClass(loadClass(beanClassName)); + } + throw new IllegalStateException( + "Failed to determine bean class of " + beanDefinition); + } + + private Class loadClass(String beanClassName) { + try { + return ClassUtils.forName(beanClassName, this.classLoader); + } + catch (ClassNotFoundException ex) { + throw new IllegalStateException("Failed to load class " + beanClassName); + } + } + + @Nullable + private T getField(BeanDefinition beanDefinition, String fieldName, + Class targetType) { + + Field field = ReflectionUtils.findField(RootBeanDefinition.class, fieldName); + ReflectionUtils.makeAccessible(field); + return targetType.cast(ReflectionUtils.getField(field, beanDefinition)); + } + + private static boolean isSimpleConvertibleType(Class type) { + return (type.isPrimitive() && type != void.class) || type == Double.class + || type == Float.class || type == Long.class || type == Integer.class + || type == Short.class || type == Character.class || type == Byte.class + || type == Boolean.class || type == String.class; + } + + static Executable resolve(RegisteredBean registeredBean) { + return new ConstructorOrFactoryMethodResolver(registeredBean.getBeanFactory()) + .resolve(registeredBean.getMergedBeanDefinition()); + } + + + enum FallbackMode { + + NONE, + + ASSIGNABLE_ELEMENT, + + TYPE_CONVERSION + + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java new file mode 100644 index 00000000000..ecaa2393b23 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -0,0 +1,185 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Executable; +import java.util.List; +import java.util.function.Predicate; + +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodReference; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.BeanDefinitionHolder; +import org.springframework.beans.factory.support.InstanceSupplier; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.core.ResolvableType; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; + +/** + * Internal {@link BeanRegistrationCodeFragments} implementation used by + * default. + * + * @author Phillip Webb + */ +class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments { + + /** + * The variable name used to hold the bean type. + */ + private static final String BEAN_TYPE_VARIABLE = "beanType"; + + + private final BeanRegistrationsCode beanRegistrationsCode; + + private final RegisteredBean registeredBean; + + private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory; + + + DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode, + RegisteredBean registeredBean, + BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory) { + + this.beanRegistrationsCode = beanRegistrationsCode; + this.registeredBean = registeredBean; + this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory; + } + + + @Override + public Class getTarget(RegisteredBean registeredBean, + Executable constructorOrFactoryMethod) { + + Class target = ClassUtils + .getUserClass(constructorOrFactoryMethod.getDeclaringClass()); + while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) { + target = registeredBean.getParent().getBeanClass(); + } + return target; + } + + @Override + public CodeBlock generateNewBeanDefinitionCode(GenerationContext generationContext, + ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) { + + CodeBlock.Builder builder = CodeBlock.builder(); + builder.addStatement(generateBeanTypeCode(beanType)); + builder.addStatement("$T $L = new $T($L)", RootBeanDefinition.class, + BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, BEAN_TYPE_VARIABLE); + return builder.build(); + } + + private CodeBlock generateBeanTypeCode(ResolvableType beanType) { + if (!beanType.hasGenerics()) { + return CodeBlock.of("$T $L = $T.class", Class.class, BEAN_TYPE_VARIABLE, + ClassUtils.getUserClass(beanType.toClass())); + } + return CodeBlock.of("$T $L = $L", ResolvableType.class, BEAN_TYPE_VARIABLE, + ResolvableTypeCodeGenerator.generateCode(beanType)); + } + + @Override + public CodeBlock generateSetBeanDefinitionPropertiesCode( + GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, RootBeanDefinition beanDefinition, + Predicate attributeFilter) { + + return new BeanDefinitionPropertiesCodeGenerator( + generationContext.getRuntimeHints(), attributeFilter, + beanRegistrationCode.getMethodGenerator(), + (name, value) -> generateValueCode(generationContext, name, value)) + .generateCode(beanDefinition); + } + + @Nullable + protected CodeBlock generateValueCode(GenerationContext generationContext, + String name, Object value) { + + RegisteredBean innerRegisteredBean = getInnerRegisteredBean(value); + if (innerRegisteredBean != null) { + BeanDefinitionMethodGenerator methodGenerator = this.beanDefinitionMethodGeneratorFactory + .getBeanDefinitionMethodGenerator(innerRegisteredBean, name); + Assert.state(methodGenerator != null, "Unexpected filtering of inner-bean"); + MethodReference generatedMethod = methodGenerator + .generateBeanDefinitionMethod(generationContext, + this.beanRegistrationsCode); + return generatedMethod.toInvokeCodeBlock(); + } + return null; + } + + @Nullable + private RegisteredBean getInnerRegisteredBean(Object value) { + if (value instanceof BeanDefinitionHolder beanDefinitionHolder) { + return RegisteredBean.ofInnerBean(this.registeredBean, beanDefinitionHolder); + } + if (value instanceof BeanDefinition beanDefinition) { + return RegisteredBean.ofInnerBean(this.registeredBean, beanDefinition); + } + return null; + } + + public CodeBlock generateSetBeanInstanceSupplierCode( + GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, CodeBlock instanceSupplierCode, + List postProcessors) { + + CodeBlock.Builder builder = CodeBlock.builder(); + if (postProcessors.isEmpty()) { + builder.addStatement("$L.setInstanceSupplier($L)", BEAN_DEFINITION_VARIABLE, + instanceSupplierCode); + return builder.build(); + } + builder.addStatement("$T $L = $L", + ParameterizedTypeName.get(InstanceSupplier.class, + this.registeredBean.getBeanClass()), + INSTANCE_SUPPLIER_VARIABLE, instanceSupplierCode); + for (MethodReference postProcessor : postProcessors) { + builder.addStatement("$L = $L.andThen($L)", INSTANCE_SUPPLIER_VARIABLE, + INSTANCE_SUPPLIER_VARIABLE, postProcessor.toCodeBlock()); + } + builder.addStatement("$L.setInstanceSupplier($L)", BEAN_DEFINITION_VARIABLE, + INSTANCE_SUPPLIER_VARIABLE); + return builder.build(); + } + + @Override + public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, + Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut) { + + return new InstanceSupplierCodeGenerator(generationContext, + beanRegistrationCode.getClassName(), + beanRegistrationCode.getMethodGenerator(), allowDirectSupplierShortcut) + .generateCode(this.registeredBean, constructorOrFactoryMethod); + } + + @Override + public CodeBlock generateReturnCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode) { + + CodeBlock.Builder builder = CodeBlock.builder(); + builder.addStatement("return $L", BEAN_DEFINITION_VARIABLE); + return builder.build(); + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java new file mode 100644 index 00000000000..6beb406a0e6 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -0,0 +1,364 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; +import java.lang.reflect.Member; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.function.Consumer; + +import org.springframework.aot.generate.AccessVisibility; +import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.hint.ExecutableHint; +import org.springframework.aot.hint.ExecutableMode; +import org.springframework.beans.factory.annotation.AutowiredArgumentsCodeGenerator; +import org.springframework.beans.factory.annotation.AutowiredInstantiationArgumentsResolver; +import org.springframework.beans.factory.support.InstanceSupplier; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.core.ResolvableType; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.util.ClassUtils; +import org.springframework.util.function.ThrowingSupplier; + +/** + * Internal code generator to create an {@link InstanceSupplier}. + *

+ * Generates code in the form:

{@code
+ * InstanceSupplier.of(TheGeneratedClass::getMyBeanInstance);
+ * }
+ * + * @author Phillip Webb + * @author Stephane Nicoll + * @since 6.0 + */ +class InstanceSupplierCodeGenerator { + + private static final String REGISTERED_BEAN_PARAMETER_NAME = "registeredBean"; + + private static final javax.lang.model.element.Modifier[] PRIVATE_STATIC = { + javax.lang.model.element.Modifier.PRIVATE, + javax.lang.model.element.Modifier.STATIC }; + + private static final CodeBlock NO_ARGS = CodeBlock.of(""); + + private static final Consumer INTROSPECT = builder -> builder + .withMode(ExecutableMode.INTROSPECT); + + + private final GenerationContext generationContext; + + private final ClassName className; + + private final MethodGenerator methodGenerator; + + private final boolean allowDirectSupplierShortcut; + + + InstanceSupplierCodeGenerator(GenerationContext generationContext, + ClassName className, MethodGenerator methodGenerator, + boolean allowDirectSupplierShortcut) { + + this.generationContext = generationContext; + this.className = className; + this.methodGenerator = methodGenerator; + this.allowDirectSupplierShortcut = allowDirectSupplierShortcut; + } + + + CodeBlock generateCode(RegisteredBean registeredBean, + Executable constructorOrFactoryMethod) { + + if (constructorOrFactoryMethod instanceof Constructor constructor) { + return generateCodeForConstructor(registeredBean, constructor); + } + if (constructorOrFactoryMethod instanceof Method method) { + return generateCodeForFactoryMethod(registeredBean, method); + } + throw new IllegalStateException( + "No suitable executor found for " + registeredBean.getBeanName()); + } + + private CodeBlock generateCodeForConstructor(RegisteredBean registeredBean, + Constructor constructor) { + + String name = registeredBean.getBeanName(); + Class declaringClass = ClassUtils + .getUserClass(constructor.getDeclaringClass()); + boolean dependsOnBean = ClassUtils.isInnerClass(declaringClass); + AccessVisibility accessVisibility = getAccessVisibility(registeredBean, + constructor); + if (accessVisibility == AccessVisibility.PUBLIC + || accessVisibility == AccessVisibility.PACKAGE_PRIVATE) { + return generateCodeForAccessibleConstructor(name, constructor, declaringClass, + dependsOnBean); + } + return generateCodeForInaccessibleConstructor(name, constructor, declaringClass, + dependsOnBean); + } + + private CodeBlock generateCodeForAccessibleConstructor(String name, + Constructor constructor, Class declaringClass, boolean dependsOnBean) { + + this.generationContext.getRuntimeHints().reflection() + .registerConstructor(constructor, INTROSPECT); + if (!dependsOnBean && constructor.getParameterCount() == 0) { + if (!this.allowDirectSupplierShortcut) { + return CodeBlock.of("$T.using($T::new)", InstanceSupplier.class, + declaringClass); + } + if (!isThrowingCheckedException(constructor)) { + return CodeBlock.of("$T::new", declaringClass); + } + return CodeBlock.of("$T.of($T::new)", ThrowingSupplier.class, + declaringClass); + } + GeneratedMethod getInstanceMethod = generateGetInstanceMethod() + .using(builder -> buildGetInstanceMethodForConstructor(builder, name, + constructor, declaringClass, dependsOnBean, PRIVATE_STATIC)); + return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, this.className, + getInstanceMethod.getName()); + } + + private CodeBlock generateCodeForInaccessibleConstructor(String name, + Constructor constructor, Class declaringClass, boolean dependsOnBean) { + + this.generationContext.getRuntimeHints().reflection() + .registerConstructor(constructor); + GeneratedMethod getInstanceMethod = generateGetInstanceMethod().using(builder -> { + builder.addJavadoc("Instantiate the bean instance for '$L'.", name); + builder.addModifiers(PRIVATE_STATIC); + builder.returns(declaringClass); + builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); + int parameterOffset = (!dependsOnBean) ? 0 : 1; + builder.addStatement( + generateResolverForConstructor(constructor, parameterOffset)); + builder.addStatement("return resolver.resolveAndInstantiate($L)", + REGISTERED_BEAN_PARAMETER_NAME); + }); + return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, this.className, + getInstanceMethod.getName()); + } + + private void buildGetInstanceMethodForConstructor(MethodSpec.Builder builder, + String name, Constructor constructor, Class declaringClass, + boolean dependsOnBean, javax.lang.model.element.Modifier... modifiers) { + + builder.addJavadoc("Create the bean instance for '$L'.", name); + builder.addModifiers(modifiers); + builder.returns(declaringClass); + builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); + if (constructor.getParameterCount() == 0) { + CodeBlock instantiationCode = generateNewInstanceCodeForConstructor( + dependsOnBean, declaringClass, NO_ARGS); + builder.addCode(generateReturnStatement(instantiationCode)); + } + else { + int parameterOffset = (!dependsOnBean) ? 0 : 1; + CodeBlock.Builder code = CodeBlock.builder(); + code.addStatement( + generateResolverForConstructor(constructor, parameterOffset)); + CodeBlock arguments = new AutowiredArgumentsCodeGenerator(declaringClass, + constructor).generateCode(constructor.getParameterTypes(), + parameterOffset); + CodeBlock newInstance = generateNewInstanceCodeForConstructor(dependsOnBean, + declaringClass, arguments); + code.addStatement("return resolver.resolve($L, (args) -> $L)", + REGISTERED_BEAN_PARAMETER_NAME, newInstance); + builder.addCode(code.build()); + } + } + + private CodeBlock generateResolverForConstructor(Constructor constructor, + int parameterOffset) { + + CodeBlock parameterTypes = generateParameterTypesCode( + constructor.getParameterTypes(), parameterOffset); + return CodeBlock.of("$T resolver = $T.forConstructor($L)", + AutowiredInstantiationArgumentsResolver.class, + AutowiredInstantiationArgumentsResolver.class, parameterTypes); + } + + private CodeBlock generateNewInstanceCodeForConstructor(boolean dependsOnBean, + Class declaringClass, CodeBlock args) { + + if (!dependsOnBean) { + return CodeBlock.of("new $T($L)", declaringClass, args); + } + return CodeBlock.of("$L.getBeanFactory().getBean($T.class).new $L($L)", + REGISTERED_BEAN_PARAMETER_NAME, declaringClass.getEnclosingClass(), + declaringClass.getSimpleName(), args); + } + + private CodeBlock generateCodeForFactoryMethod(RegisteredBean registeredBean, + Method factoryMethod) { + + String name = registeredBean.getBeanName(); + Class declaringClass = ClassUtils + .getUserClass(factoryMethod.getDeclaringClass()); + boolean dependsOnBean = !Modifier.isStatic(factoryMethod.getModifiers()); + AccessVisibility accessVisibility = getAccessVisibility(registeredBean, + factoryMethod); + if (accessVisibility == AccessVisibility.PUBLIC + || accessVisibility == AccessVisibility.PACKAGE_PRIVATE) { + return generateCodeForAccessibleFactoryMethod(name, factoryMethod, + declaringClass, dependsOnBean); + } + return generateCodeForInaccessibleFactoryMethod(name, factoryMethod, + declaringClass); + } + + private CodeBlock generateCodeForAccessibleFactoryMethod(String name, + Method factoryMethod, Class declaringClass, boolean dependsOnBean) { + this.generationContext.getRuntimeHints().reflection() + .registerMethod(factoryMethod, INTROSPECT); + if (!dependsOnBean && factoryMethod.getParameterCount() == 0) { + if (!this.allowDirectSupplierShortcut) { + return CodeBlock.of("$T.using($T::$L)", InstanceSupplier.class, + declaringClass, factoryMethod.getName()); + } + if (!isThrowingCheckedException(factoryMethod)) { + return CodeBlock.of("$T::$L", declaringClass, factoryMethod.getName()); + } + return CodeBlock.of("$T.of($T::$L)", ThrowingSupplier.class, declaringClass, + factoryMethod.getName()); + } + GeneratedMethod getInstanceMethod = generateGetInstanceMethod() + .using(builder -> buildGetInstanceMethodForFactoryMethod(builder, name, + factoryMethod, declaringClass, dependsOnBean, PRIVATE_STATIC)); + return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, this.className, + getInstanceMethod.getName()); + } + + private CodeBlock generateCodeForInaccessibleFactoryMethod(String name, + Method factoryMethod, Class declaringClass) { + + this.generationContext.getRuntimeHints().reflection() + .registerMethod(factoryMethod); + GeneratedMethod getInstanceMethod = generateGetInstanceMethod().using(builder -> { + builder.addJavadoc("Instantiate the bean instance for '$L'.", name); + builder.addModifiers(PRIVATE_STATIC); + builder.returns(factoryMethod.getReturnType()); + builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); + builder.addStatement(generateResolverForFactoryMethod(factoryMethod, + declaringClass, factoryMethod.getName())); + builder.addStatement("return resolver.resolveAndInstantiate($L)", + REGISTERED_BEAN_PARAMETER_NAME); + }); + return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, this.className, + getInstanceMethod.getName()); + } + + private void buildGetInstanceMethodForFactoryMethod(MethodSpec.Builder builder, + String name, Method factoryMethod, Class declaringClass, + boolean dependsOnBean, javax.lang.model.element.Modifier... modifiers) { + + String factoryMethodName = factoryMethod.getName(); + builder.addJavadoc("Get the bean instance for '$L'.", name); + builder.addModifiers(modifiers); + builder.returns(factoryMethod.getReturnType()); + if (isThrowingCheckedException(factoryMethod)) { + builder.addException(Exception.class); + } + builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); + if (factoryMethod.getParameterCount() == 0) { + CodeBlock instantiationCode = generateNewInstanceCodeForMethod(dependsOnBean, + declaringClass, factoryMethodName, NO_ARGS); + builder.addCode(generateReturnStatement(instantiationCode)); + } + else { + CodeBlock.Builder code = CodeBlock.builder(); + code.addStatement(generateResolverForFactoryMethod(factoryMethod, + declaringClass, factoryMethodName)); + CodeBlock arguments = new AutowiredArgumentsCodeGenerator(declaringClass, + factoryMethod).generateCode(factoryMethod.getParameterTypes()); + CodeBlock newInstance = generateNewInstanceCodeForMethod(dependsOnBean, + declaringClass, factoryMethodName, arguments); + code.addStatement("return resolver.resolve($L, (args) -> $L)", + REGISTERED_BEAN_PARAMETER_NAME, newInstance); + builder.addCode(code.build()); + } + } + + private CodeBlock generateResolverForFactoryMethod(Method factoryMethod, + Class declaringClass, String factoryMethodName) { + + if (factoryMethod.getParameterCount() == 0) { + return CodeBlock.of("$T resolver = $T.forFactoryMethod($T.class, $S)", + AutowiredInstantiationArgumentsResolver.class, + AutowiredInstantiationArgumentsResolver.class, declaringClass, + factoryMethodName); + } + CodeBlock parameterTypes = generateParameterTypesCode( + factoryMethod.getParameterTypes(), 0); + return CodeBlock.of("$T resolver = $T.forFactoryMethod($T.class, $S, $L)", + AutowiredInstantiationArgumentsResolver.class, + AutowiredInstantiationArgumentsResolver.class, declaringClass, + factoryMethodName, parameterTypes); + } + + private CodeBlock generateNewInstanceCodeForMethod(boolean dependsOnBean, + Class declaringClass, String factoryMethodName, CodeBlock args) { + + if (!dependsOnBean) { + return CodeBlock.of("$T.$L($L)", declaringClass, factoryMethodName, args); + } + return CodeBlock.of("$L.getBeanFactory().getBean($T.class).$L($L)", + REGISTERED_BEAN_PARAMETER_NAME, declaringClass, factoryMethodName, args); + } + + private CodeBlock generateReturnStatement(CodeBlock instantiationCode) { + CodeBlock.Builder code = CodeBlock.builder(); + code.addStatement("return $L", instantiationCode); + return code.build(); + } + + protected AccessVisibility getAccessVisibility(RegisteredBean registeredBean, + Member member) { + + AccessVisibility beanTypeAccessVisibility = AccessVisibility + .forResolvableType(registeredBean.getBeanType()); + AccessVisibility memberAccessVisibility = AccessVisibility.forMember(member); + return AccessVisibility.lowest(beanTypeAccessVisibility, memberAccessVisibility); + } + + private CodeBlock generateParameterTypesCode(Class[] parameterTypes, int offset) { + CodeBlock.Builder builder = CodeBlock.builder(); + for (int i = offset; i < parameterTypes.length; i++) { + builder.add(i != offset ? ", " : ""); + builder.add("$T.class", parameterTypes[i]); + } + return builder.build(); + } + + private GeneratedMethod generateGetInstanceMethod() { + return this.methodGenerator.generateMethod("get", "instance"); + } + + private boolean isThrowingCheckedException(Executable executable) { + return Arrays.stream(executable.getGenericExceptionTypes()) + .map(ResolvableType::forType).map(ResolvableType::toClass) + .anyMatch(Exception.class::isAssignableFrom); + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/ResolvableTypeCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/ResolvableTypeCodeGenerator.java new file mode 100644 index 00000000000..537d79030b2 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/ResolvableTypeCodeGenerator.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.util.Arrays; + +import org.springframework.core.ResolvableType; +import org.springframework.javapoet.CodeBlock; +import org.springframework.util.ClassUtils; + +/** + * Internal code generator used to support {@link ResolvableType}. + * + * @author Stephane Nicoll + * @author Phillip Webb + * @since 6.0 + */ +final class ResolvableTypeCodeGenerator { + + + private ResolvableTypeCodeGenerator() { + } + + + public static CodeBlock generateCode(ResolvableType resolvableType) { + return generateCode(resolvableType, false); + } + + private static CodeBlock generateCode(ResolvableType resolvableType, boolean allowClassResult) { + if (ResolvableType.NONE.equals(resolvableType)) { + return CodeBlock.of("$T.NONE", ResolvableType.class); + } + Class type = ClassUtils.getUserClass(resolvableType.toClass()); + if (resolvableType.hasGenerics()) { + return generateCodeWithGenerics(resolvableType, type); + } + if (allowClassResult) { + return CodeBlock.of("$T.class", type); + } + return CodeBlock.of("$T.forClass($T.class)", ResolvableType.class, type); + } + + private static CodeBlock generateCodeWithGenerics(ResolvableType target, Class type) { + ResolvableType[] generics = target.getGenerics(); + boolean hasNoNestedGenerics = Arrays.stream(generics).noneMatch(ResolvableType::hasGenerics); + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("$T.forClassWithGenerics($T.class", ResolvableType.class, type); + for (ResolvableType generic : generics) { + builder.add(", $L", generateCode(generic, hasNoNestedGenerics)); + } + builder.add(")"); + return builder.build(); + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/package-info.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/package-info.java new file mode 100644 index 00000000000..bf7c97a915d --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/package-info.java @@ -0,0 +1,9 @@ +/** + * AOT support for bean factories. + */ +@NonNullApi +@NonNullFields +package org.springframework.beans.factory.aot; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-beans/src/main/resources/META-INF/spring/aot.factories b/spring-beans/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..b95137025d8 --- /dev/null +++ b/spring-beans/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor=\ +org.springframework.beans.factory.aot.BeanRegistrationsAotProcessor diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredArgumentsCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredArgumentsCodeGeneratorTests.java new file mode 100644 index 00000000000..3027873f77b --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredArgumentsCodeGeneratorTests.java @@ -0,0 +1,200 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.annotation; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Method; + +import org.junit.jupiter.api.Test; + +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link AutowiredArgumentsCodeGenerator}. + * + * @author Phillip Webb + * @author Stephane Nicoll + */ +class AutowiredArgumentsCodeGeneratorTests { + + @Test + void generateCodeWhenNoArguments() { + Method method = ReflectionUtils.findMethod(UnambiguousMethods.class, "zero"); + AutowiredArgumentsCodeGenerator generator = new AutowiredArgumentsCodeGenerator( + UnambiguousMethods.class, method); + assertThat(generator.generateCode(method.getParameterTypes())).hasToString(""); + } + + @Test + void generatedCodeWhenSingleArgument() { + Method method = ReflectionUtils.findMethod(UnambiguousMethods.class, "one", + String.class); + AutowiredArgumentsCodeGenerator generator = new AutowiredArgumentsCodeGenerator( + UnambiguousMethods.class, method); + assertThat(generator.generateCode(method.getParameterTypes())) + .hasToString("args.get(0)"); + } + + @Test + void generateCodeWhenMulitpleArguments() { + Method method = ReflectionUtils.findMethod(UnambiguousMethods.class, "three", + String.class, Integer.class, Boolean.class); + AutowiredArgumentsCodeGenerator generator = new AutowiredArgumentsCodeGenerator( + UnambiguousMethods.class, method); + assertThat(generator.generateCode(method.getParameterTypes())) + .hasToString("args.get(0), args.get(1), args.get(2)"); + } + + @Test + void generateCodeWhenMulitpleArgumentsWithOffset() { + Constructor constructor = Outer.Nested.class.getDeclaredConstructors()[0]; + AutowiredArgumentsCodeGenerator generator = new AutowiredArgumentsCodeGenerator( + Outer.Nested.class, constructor); + assertThat(generator.generateCode(constructor.getParameterTypes(), 1)) + .hasToString("args.get(0), args.get(1)"); + } + + @Test + void generateCodeWhenAmbiguousConstructor() throws Exception { + Constructor constructor = AmbiguousConstructors.class + .getDeclaredConstructor(String.class, Integer.class); + AutowiredArgumentsCodeGenerator generator = new AutowiredArgumentsCodeGenerator( + AmbiguousConstructors.class, constructor); + assertThat(generator.generateCode(constructor.getParameterTypes())).hasToString( + "args.get(0, java.lang.String.class), args.get(1, java.lang.Integer.class)"); + } + + @Test + void generateCodeWhenUnambiguousConstructor() throws Exception { + Constructor constructor = UnambiguousConstructors.class + .getDeclaredConstructor(String.class, Integer.class); + AutowiredArgumentsCodeGenerator generator = new AutowiredArgumentsCodeGenerator( + UnambiguousConstructors.class, constructor); + assertThat(generator.generateCode(constructor.getParameterTypes())) + .hasToString("args.get(0), args.get(1)"); + } + + @Test + void generateCodeWhenAmbiguousMethod() { + Method method = ReflectionUtils.findMethod(AmbiguousMethods.class, "two", + String.class, Integer.class); + AutowiredArgumentsCodeGenerator generator = new AutowiredArgumentsCodeGenerator( + AmbiguousMethods.class, method); + assertThat(generator.generateCode(method.getParameterTypes())).hasToString( + "args.get(0, java.lang.String.class), args.get(1, java.lang.Integer.class)"); + } + + @Test + void generateCodeWhenAmbiguousSubclassMethod() { + Method method = ReflectionUtils.findMethod(UnambiguousMethods.class, "two", + String.class, Integer.class); + AutowiredArgumentsCodeGenerator generator = new AutowiredArgumentsCodeGenerator( + AmbiguousSubclassMethods.class, method); + assertThat(generator.generateCode(method.getParameterTypes())).hasToString( + "args.get(0, java.lang.String.class), args.get(1, java.lang.Integer.class)"); + } + + @Test + void generateCodeWhenUnambiguousMethod() { + Method method = ReflectionUtils.findMethod(UnambiguousMethods.class, "two", + String.class, Integer.class); + AutowiredArgumentsCodeGenerator generator = new AutowiredArgumentsCodeGenerator( + UnambiguousMethods.class, method); + assertThat(generator.generateCode(method.getParameterTypes())) + .hasToString("args.get(0), args.get(1)"); + } + + @Test + void generateCodeWithCustomArgVariable() { + Method method = ReflectionUtils.findMethod(UnambiguousMethods.class, "one", + String.class); + AutowiredArgumentsCodeGenerator generator = new AutowiredArgumentsCodeGenerator( + UnambiguousMethods.class, method); + assertThat(generator.generateCode(method.getParameterTypes(), 0, "objs")) + .hasToString("objs.get(0)"); + } + + static class Outer { + + class Nested { + + Nested(String a, Integer b) { + } + + } + + } + + static class UnambiguousMethods { + + void zero() { + } + + void one(String a) { + } + + void two(String a, Integer b) { + } + + void three(String a, Integer b, Boolean c) { + } + + } + + static class AmbiguousMethods { + + void two(String a, Integer b) { + } + + void two(Integer b, String a) { + } + + } + + static class AmbiguousSubclassMethods extends UnambiguousMethods { + + void two(Integer a, String b) { + } + + } + + static class UnambiguousConstructors { + + UnambiguousConstructors() { + } + + UnambiguousConstructors(String a) { + } + + UnambiguousConstructors(String a, Integer b) { + } + + } + + static class AmbiguousConstructors { + + AmbiguousConstructors(String a, Integer b) { + } + + AmbiguousConstructors(Integer b, String a) { + } + + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/AotFactoriesLoaderTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/AotFactoriesLoaderTests.java new file mode 100644 index 00000000000..d252a3663ed --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/AotFactoriesLoaderTests.java @@ -0,0 +1,100 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.core.Ordered; +import org.springframework.core.mock.MockSpringFactoriesLoader; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link AotFactoriesLoader}. + * + * @author Phillip Webb + */ +class AotFactoriesLoaderTests { + + @Test + void createWhenBeanFactoryIsNullThrowsException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new AotFactoriesLoader(null)) + .withMessage("BeanFactory must not be null"); + } + + @Test + void createWhenSpringFactoriesLoaderIsNullThrowsException() { + ListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + assertThatIllegalArgumentException() + .isThrownBy(() -> new AotFactoriesLoader(beanFactory, null)) + .withMessage("FactoriesLoader must not be null"); + } + + @Test + void loadLoadsFromBeanFactoryAndSpringFactoriesLoaderInOrder() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("b1", new TestFactoryImpl(0, "b1")); + beanFactory.registerSingleton("b2", new TestFactoryImpl(2, "b2")); + MockSpringFactoriesLoader springFactoriesLoader = new MockSpringFactoriesLoader(); + springFactoriesLoader.addInstance(TestFactory.class, + new TestFactoryImpl(1, "l1")); + springFactoriesLoader.addInstance(TestFactory.class, + new TestFactoryImpl(3, "l2")); + AotFactoriesLoader loader = new AotFactoriesLoader(beanFactory, + springFactoriesLoader); + List loaded = loader.load(TestFactory.class); + assertThat(loaded).hasSize(4); + assertThat(loaded.get(0)).hasToString("b1"); + assertThat(loaded.get(1)).hasToString("l1"); + assertThat(loaded.get(2)).hasToString("b2"); + assertThat(loaded.get(3)).hasToString("l2"); + } + + static interface TestFactory { + + } + + static class TestFactoryImpl implements TestFactory, Ordered { + + private final int order; + + private final String name; + + TestFactoryImpl(int order, String name) { + this.order = order; + this.name = name; + } + + @Override + public int getOrder() { + return this.order; + } + + @Override + public String toString() { + return this.name; + } + + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorFactoryTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorFactoryTests.java new file mode 100644 index 00000000000..f9dc7040cf3 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorFactoryTests.java @@ -0,0 +1,164 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.core.Ordered; +import org.springframework.core.mock.MockSpringFactoriesLoader; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link BeanDefinitionMethodGeneratorFactory}. + * + * @author Phillip Webb + */ +class BeanDefinitionMethodGeneratorFactoryTests { + + @Test + void getBeanDefinitionMethodGeneratorWhenExcludedByBeanRegistrationExcludeFilterReturnsNull() { + MockSpringFactoriesLoader springFactoriesLoader = new MockSpringFactoriesLoader(); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + springFactoriesLoader.addInstance(BeanRegistrationExcludeFilter.class, + new MockBeanRegistrationExcludeFilter(true, 0)); + RegisteredBean registeredBean = registerTestBean(beanFactory); + BeanDefinitionMethodGeneratorFactory methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( + new AotFactoriesLoader(beanFactory, springFactoriesLoader)); + assertThat(methodGeneratorFactory.getBeanDefinitionMethodGenerator(registeredBean, + null)).isNull(); + } + + @Test + void getBeanDefinitionMethodGeneratorWhenExcludedByBeanRegistrationExcludeFilterBeanReturnsNull() { + MockSpringFactoriesLoader springFactoriesLoader = new MockSpringFactoriesLoader(); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + RegisteredBean registeredBean = registerTestBean(beanFactory); + beanFactory.registerSingleton("filter", + new MockBeanRegistrationExcludeFilter(true, 0)); + BeanDefinitionMethodGeneratorFactory methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( + new AotFactoriesLoader(beanFactory, springFactoriesLoader)); + assertThat(methodGeneratorFactory.getBeanDefinitionMethodGenerator(registeredBean, + null)).isNull(); + } + + @Test + void getBeanDefinitionMethodGeneratorConsidersFactoryLoadedExcludeFiltersAndBeansInOrderedOrder() { + MockBeanRegistrationExcludeFilter filter1 = new MockBeanRegistrationExcludeFilter( + false, 1); + MockBeanRegistrationExcludeFilter filter2 = new MockBeanRegistrationExcludeFilter( + false, 2); + MockBeanRegistrationExcludeFilter filter3 = new MockBeanRegistrationExcludeFilter( + false, 3); + MockBeanRegistrationExcludeFilter filter4 = new MockBeanRegistrationExcludeFilter( + true, 4); + MockBeanRegistrationExcludeFilter filter5 = new MockBeanRegistrationExcludeFilter( + true, 5); + MockBeanRegistrationExcludeFilter filter6 = new MockBeanRegistrationExcludeFilter( + true, 6); + MockSpringFactoriesLoader springFactoriesLoader = new MockSpringFactoriesLoader(); + springFactoriesLoader.addInstance(BeanRegistrationExcludeFilter.class, filter3, + filter1, filter5); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("filter4", filter4); + beanFactory.registerSingleton("filter2", filter2); + beanFactory.registerSingleton("filter6", filter6); + RegisteredBean registeredBean = registerTestBean(beanFactory); + BeanDefinitionMethodGeneratorFactory methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( + new AotFactoriesLoader(beanFactory, springFactoriesLoader)); + assertThat(methodGeneratorFactory.getBeanDefinitionMethodGenerator(registeredBean, + null)).isNull(); + assertThat(filter1.wasCalled()).isTrue(); + assertThat(filter2.wasCalled()).isTrue(); + assertThat(filter3.wasCalled()).isTrue(); + assertThat(filter4.wasCalled()).isTrue(); + assertThat(filter5.wasCalled()).isFalse(); + assertThat(filter6.wasCalled()).isFalse(); + } + + @Test + void getBeanDefinitionMethodGeneratorAddsContributionsFromProcessors() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BeanRegistrationAotContribution beanContribution = mock( + BeanRegistrationAotContribution.class); + BeanRegistrationAotProcessor processorBean = registeredBean -> beanContribution; + beanFactory.registerSingleton("processorBean", processorBean); + MockSpringFactoriesLoader springFactoriesLoader = new MockSpringFactoriesLoader(); + BeanRegistrationAotContribution loaderContribution = mock( + BeanRegistrationAotContribution.class); + BeanRegistrationAotProcessor loaderProcessor = registeredBean -> loaderContribution; + springFactoriesLoader.addInstance(BeanRegistrationAotProcessor.class, + loaderProcessor); + RegisteredBean registeredBean = registerTestBean(beanFactory); + BeanDefinitionMethodGeneratorFactory methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( + new AotFactoriesLoader(beanFactory, springFactoriesLoader)); + BeanDefinitionMethodGenerator methodGenerator = methodGeneratorFactory + .getBeanDefinitionMethodGenerator(registeredBean, null); + assertThat(methodGenerator).extracting("aotContributions").asList() + .containsExactly(beanContribution, loaderContribution); + } + + private RegisteredBean registerTestBean(DefaultListableBeanFactory beanFactory) { + beanFactory.registerBeanDefinition("test", BeanDefinitionBuilder + .rootBeanDefinition(TestBean.class).getBeanDefinition()); + return RegisteredBean.of(beanFactory, "test"); + } + + static class MockBeanRegistrationExcludeFilter + implements BeanRegistrationExcludeFilter, Ordered { + + private final boolean excluded; + + private final int order; + + private RegisteredBean registeredBean; + + MockBeanRegistrationExcludeFilter(boolean excluded, int order) { + this.excluded = excluded; + this.order = order; + } + + @Override + public boolean isExcluded(RegisteredBean registeredBean) { + this.registeredBean = registeredBean; + return this.excluded; + } + + @Override + public int getOrder() { + return this.order; + } + + boolean wasCalled() { + return this.registeredBean != null; + } + + } + + static class TestBean { + + } + + static class InnerTestBean { + + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java new file mode 100644 index 00000000000..8c08e401f20 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -0,0 +1,402 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.DefaultGenerationContext; +import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.InMemoryGeneratedFiles; +import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.test.generator.compile.CompileWithTargetClassAccess; +import org.springframework.aot.test.generator.compile.Compiled; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.aot.test.generator.file.SourceFile; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.InstanceSupplier; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.AnnotatedBean; +import org.springframework.beans.testfixture.beans.GenericBean; +import org.springframework.beans.testfixture.beans.TestBean; +import org.springframework.core.ResolvableType; +import org.springframework.core.mock.MockSpringFactoriesLoader; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.javapoet.TypeSpec; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link BeanDefinitionMethodGenerator} and + * {@link DefaultBeanRegistrationCodeFragments}. + * + * @author Phillip Webb + */ +class BeanDefinitionMethodGeneratorTests { + + private InMemoryGeneratedFiles generatedFiles; + + private DefaultGenerationContext generationContext; + + private DefaultListableBeanFactory beanFactory; + + private MockSpringFactoriesLoader springFactoriesLoader; + + private MockBeanRegistrationsCode beanRegistrationsCode; + + private BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; + + @BeforeEach + void setup() { + this.generatedFiles = new InMemoryGeneratedFiles(); + this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.beanFactory = new DefaultListableBeanFactory(); + this.springFactoriesLoader = new MockSpringFactoriesLoader(); + this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( + new AotFactoriesLoader(this.beanFactory, this.springFactoriesLoader)); + this.beanRegistrationsCode = new MockBeanRegistrationsCode( + ClassName.get("__", "Registration")); + } + + @Test + void generateBeanDefinitionMethodGeneratesMethod() { + RegisteredBean registeredBean = registerBean( + new RootBeanDefinition(TestBean.class)); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList(), Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + testCompiledResult(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("beanType = TestBean.class"); + assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateBeanDefinitionMethodWhenHasGenericsGeneratesMethod() { + RegisteredBean registeredBean = registerBean(new RootBeanDefinition( + ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class))); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList(), Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + testCompiledResult(method, (actual, compiled) -> { + assertThat(actual.getResolvableType().resolve()).isEqualTo(GenericBean.class); + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains( + "beanType = ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class)"); + assertThat(sourceFile).contains("setInstanceSupplier(GenericBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateBeanDefinitionMethodWhenHasInstancePostProcessorGeneratesMethod() { + RegisteredBean registeredBean = registerBean( + new RootBeanDefinition(TestBean.class)); + BeanRegistrationAotContribution aotContribution = (generationContext, + beanRegistrationCode) -> { + GeneratedMethod method = beanRegistrationCode.getMethodGenerator() + .generateMethod("postProcess") + .using(builder -> builder.addModifiers(Modifier.STATIC) + .addParameter(RegisteredBean.class, "registeredBean") + .addParameter(TestBean.class, "testBean") + .returns(TestBean.class).addCode("return new $T($S);", + TestBean.class, "postprocessed")); + beanRegistrationCode.addInstancePostProcessor(MethodReference.ofStatic( + beanRegistrationCode.getClassName(), method.getName().toString())); + }; + List aotContributions = Collections + .singletonList(aotContribution); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, aotContributions, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + testCompiledResult(method, (actual, compiled) -> { + assertThat(actual.getBeanClass()).isEqualTo(TestBean.class); + InstanceSupplier supplier = (InstanceSupplier) actual + .getInstanceSupplier(); + try { + TestBean instance = (TestBean) supplier.get(registeredBean); + assertThat(instance.getName()).isEqualTo("postprocessed"); + } + catch (Exception ex) { + } + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("instanceSupplier.andThen("); + }); + } + + @Test + void generateBeanDefinitionMethodWhenHasAttributeFilterGeneratesMethod() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class); + beanDefinition.setAttribute("a", "A"); + beanDefinition.setAttribute("b", "B"); + RegisteredBean registeredBean = registerBean(beanDefinition); + List fragmentCustomizers = Collections + .singletonList(this::customizeWithAttributeFilter); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList(), fragmentCustomizers); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + testCompiledResult(method, (actual, compiled) -> { + assertThat(actual.getAttribute("a")).isEqualTo("A"); + assertThat(actual.getAttribute("b")).isNull(); + }); + } + + private BeanRegistrationCodeFragments customizeWithAttributeFilter( + RegisteredBean registeredBean, BeanRegistrationCodeFragments codeFragments) { + return new BeanRegistrationCodeFragments(codeFragments) { + + @Override + public CodeBlock generateSetBeanDefinitionPropertiesCode( + GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, + RootBeanDefinition beanDefinition, + Predicate attributeFilter) { + return super.generateSetBeanDefinitionPropertiesCode(generationContext, + beanRegistrationCode, beanDefinition, name -> "a".equals(name)); + } + + }; + } + + @Test + void generateBeanDefinitionMethodWhenInnerBeanGeneratesMethod() { + RegisteredBean parent = registerBean(new RootBeanDefinition(TestBean.class)); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(parent, + new RootBeanDefinition(AnnotatedBean.class)); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, innerBean, "testInnerBean", + Collections.emptyList(), Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + testCompiledResult(method, (actual, compiled) -> { + assertThat(compiled.getSourceFile(".*BeanDefinitions")) + .contains("Get the inner-bean definition for 'testInnerBean'"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateBeanDefinitionMethodWhenHasInnerBeanPropertyValueGeneratesMethod() { + RootBeanDefinition innerBeanDefinition = (RootBeanDefinition) BeanDefinitionBuilder + .rootBeanDefinition(AnnotatedBean.class) + .setRole(BeanDefinition.ROLE_INFRASTRUCTURE).setPrimary(true) + .getBeanDefinition(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class); + beanDefinition.getPropertyValues().add("name", innerBeanDefinition); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList(), Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + testCompiledResult(method, (actual, compiled) -> { + RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual + .getPropertyValues().get("name"); + assertThat(actualInnerBeanDefinition.isPrimary()).isTrue(); + assertThat(actualInnerBeanDefinition.getRole()) + .isEqualTo(BeanDefinition.ROLE_INFRASTRUCTURE); + Supplier innerInstanceSupplier = actualInnerBeanDefinition + .getInstanceSupplier(); + try { + assertThat(innerInstanceSupplier.get()).isInstanceOf(AnnotatedBean.class); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + }); + } + + @Test + void generateBeanDefinitionMethodWhenHasInnerBeanConstructorValueGeneratesMethod() { + RootBeanDefinition innerBeanDefinition = (RootBeanDefinition) BeanDefinitionBuilder + .rootBeanDefinition(String.class) + .setRole(BeanDefinition.ROLE_INFRASTRUCTURE).setPrimary(true) + .getBeanDefinition(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class); + ValueHolder valueHolder = new ValueHolder(innerBeanDefinition); + valueHolder.setName("second"); + beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, + valueHolder); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList(), Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + testCompiledResult(method, (actual, compiled) -> { + RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual + .getConstructorArgumentValues() + .getIndexedArgumentValue(0, RootBeanDefinition.class).getValue(); + assertThat(actualInnerBeanDefinition.isPrimary()).isTrue(); + assertThat(actualInnerBeanDefinition.getRole()) + .isEqualTo(BeanDefinition.ROLE_INFRASTRUCTURE); + Supplier innerInstanceSupplier = actualInnerBeanDefinition + .getInstanceSupplier(); + try { + assertThat(innerInstanceSupplier.get()).isInstanceOf(String.class); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + assertThat(compiled.getSourceFile(".*BeanDefinitions")) + .contains("getSecondBeanDefinition()"); + }); + } + + @Test + void generateBeanDefinitionMethodWhenHasAotContributionsAppliesContributions() { + RegisteredBean registeredBean = registerBean( + new RootBeanDefinition(TestBean.class)); + List aotContributions = new ArrayList<>(); + aotContributions + .add((generationContext, beanRegistrationCode) -> beanRegistrationCode + .getMethodGenerator().generateMethod("aotContributedMethod") + .using(builder -> builder.addComment("Example Contribution"))); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, aotContributions, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + testCompiledResult(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("AotContributedMethod()"); + assertThat(sourceFile).contains("Example Contribution"); + }); + } + + @Test + void generateBeanDefinitionMethodWhenHasBeanRegistrationCodeFragmentsCustomizerReturnsCodeGeneratesMethod() { + RegisteredBean registeredBean = registerBean( + new RootBeanDefinition(TestBean.class)); + List codeFragmentsCustomizers = new ArrayList<>(); + codeFragmentsCustomizers.add(this::customizeBeanRegistrationCodeFragments); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList(), codeFragmentsCustomizers); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + testCompiledResult(method, + (actual, compiled) -> assertThat( + compiled.getSourceFile(".*BeanDefinitions")) + .contains("// Custom Code")); + } + + private BeanRegistrationCodeFragments customizeBeanRegistrationCodeFragments( + RegisteredBean registeredBean, BeanRegistrationCodeFragments codeFragments) { + return new BeanRegistrationCodeFragments(codeFragments) { + + @Override + public CodeBlock generateReturnCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode) { + CodeBlock.Builder builder = CodeBlock.builder(); + builder.addStatement("// Custom Code"); + builder.add(super.generateReturnCode(generationContext, + beanRegistrationCode)); + return builder.build(); + } + + }; + } + + @Test + @CompileWithTargetClassAccess(classes = PackagePrivateTestBean.class) + void generateBeanDefinitionMethodWhenPackagePrivateBean() { + RegisteredBean registeredBean = registerBean( + new RootBeanDefinition(PackagePrivateTestBean.class)); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList(), Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + testCompiledResult(method, false, (actual, compiled) -> { + DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); + freshBeanFactory.registerBeanDefinition("test", actual); + Object bean = freshBeanFactory.getBean("test"); + assertThat(bean).isInstanceOf(PackagePrivateTestBean.class); + assertThat(compiled.getSourceFileFromPackage( + PackagePrivateTestBean.class.getPackageName())).isNotNull(); + }); + } + + private RegisteredBean registerBean(RootBeanDefinition beanDefinition) { + String beanName = "testBean"; + this.beanFactory.registerBeanDefinition(beanName, beanDefinition); + RegisteredBean registeredBean = RegisteredBean.of(this.beanFactory, beanName); + return registeredBean; + } + + @SuppressWarnings("unchecked") + private void testCompiledResult(MethodReference method, + BiConsumer result) { + testCompiledResult(method, false, result); + } + + @SuppressWarnings("unchecked") + private void testCompiledResult(MethodReference method, boolean targetClassAccess, + BiConsumer result) { + this.generationContext.writeGeneratedContent(); + JavaFile javaFile = generateJavaFile(method); + TestCompiler.forSystem().withFiles(this.generatedFiles).printFiles(System.out) + .compile(javaFile::writeTo, compiled -> result.accept( + (RootBeanDefinition) compiled.getInstance(Supplier.class).get(), + compiled)); + } + + private JavaFile generateJavaFile(MethodReference method) { + TypeSpec.Builder builder = TypeSpec.classBuilder("Registration"); + builder.addModifiers(Modifier.PUBLIC); + builder.addSuperinterface( + ParameterizedTypeName.get(Supplier.class, BeanDefinition.class)); + builder.addMethod(MethodSpec.methodBuilder("get").addModifiers(Modifier.PUBLIC) + .returns(BeanDefinition.class) + .addCode("return $L;", method.toInvokeCodeBlock()).build()); + this.beanRegistrationsCode.getGeneratedMethods() + .doWithMethodSpecs(builder::addMethod); + return JavaFile.builder("__", builder.build()).build(); + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java new file mode 100644 index 00000000000..cba19031d51 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java @@ -0,0 +1,441 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.hint.ExecutableMode; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.test.generator.compile.Compiled; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.BeanReference; +import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; +import org.springframework.beans.factory.config.RuntimeBeanNameReference; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.ManagedList; +import org.springframework.beans.factory.support.ManagedMap; +import org.springframework.beans.factory.support.ManagedSet; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.javapoet.TypeSpec; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link BeanDefinitionPropertiesCodeGenerator}. + * + * @author Phillip Webb + * @author Stephane Nicoll + */ +class BeanDefinitionPropertiesCodeGeneratorTests { + + private final RootBeanDefinition beanDefinition = new RootBeanDefinition(); + + private final GeneratedMethods generatedMethods = new GeneratedMethods(); + + private final RuntimeHints hints = new RuntimeHints(); + + private BeanDefinitionPropertiesCodeGenerator generator = new BeanDefinitionPropertiesCodeGenerator( + this.hints, attribute -> true, this.generatedMethods, (name, value) -> null); + + + @Test + void setPrimaryWhenFalse() { + this.beanDefinition.setPrimary(false); + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()).doesNotContain("setPrimary"); + assertThat(actual.isPrimary()).isFalse(); + }); + } + + @Test + void setPrimaryWhenTrue() { + this.beanDefinition.setPrimary(true); + testCompiledResult((actual, compiled) -> assertThat(actual.isPrimary()).isTrue()); + } + + @Test + void setScopeWhenEmptyString() { + this.beanDefinition.setScope(""); + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()).doesNotContain("setScope"); + assertThat(actual.getScope()).isEmpty(); + }); + } + + @Test + void setScopeWhenSingleton() { + this.beanDefinition.setScope("singleton"); + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()).doesNotContain("setScope"); + assertThat(actual.getScope()).isEmpty(); + }); + } + + @Test + void setScopeWhenOther() { + this.beanDefinition.setScope("prototype"); + testCompiledResult((actual, compiled) -> assertThat(actual.getScope()) + .isEqualTo("prototype")); + } + + @Test + void setDependsOnWhenEmpty() { + this.beanDefinition.setDependsOn(); + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()).doesNotContain("setDependsOn"); + assertThat(actual.getDependsOn()).isNull(); + }); + } + + @Test + void setDependsOnWhenNotEmpty() { + this.beanDefinition.setDependsOn("a", "b", "c"); + testCompiledResult((actual, compiled) -> assertThat(actual.getDependsOn()) + .containsExactly("a", "b", "c")); + } + + @Test + void setLazyInitWhenNoSet() { + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()).doesNotContain("setLazyInit"); + assertThat(actual.isLazyInit()).isFalse(); + assertThat(actual.getLazyInit()).isNull(); + }); + } + + @Test + void setLazyInitWhenFalse() { + this.beanDefinition.setLazyInit(false); + testCompiledResult((actual, compiled) -> { + assertThat(actual.isLazyInit()).isFalse(); + assertThat(actual.getLazyInit()).isFalse(); + }); + } + + @Test + void setLazyInitWhenTrue() { + this.beanDefinition.setLazyInit(true); + testCompiledResult((actual, compiled) -> { + assertThat(actual.isLazyInit()).isTrue(); + assertThat(actual.getLazyInit()).isTrue(); + }); + } + + @Test + void setAutowireCandidateWhenFalse() { + this.beanDefinition.setAutowireCandidate(false); + testCompiledResult( + (actual, compiled) -> assertThat(actual.isAutowireCandidate()).isFalse()); + } + + @Test + void setAutowireCandidateWhenTrue() { + this.beanDefinition.setAutowireCandidate(true); + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()).doesNotContain("setAutowireCandidate"); + assertThat(actual.isAutowireCandidate()).isTrue(); + }); + } + + @Test + void setSyntheticWhenFalse() { + this.beanDefinition.setSynthetic(false); + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()).doesNotContain("setSynthetic"); + assertThat(actual.isSynthetic()).isFalse(); + }); + } + + @Test + void setSyntheticWhenTrue() { + this.beanDefinition.setSynthetic(true); + testCompiledResult( + (actual, compiled) -> assertThat(actual.isSynthetic()).isTrue()); + } + + @Test + void setRoleWhenApplication() { + this.beanDefinition.setRole(BeanDefinition.ROLE_APPLICATION); + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()).doesNotContain("setRole"); + assertThat(actual.getRole()).isEqualTo(BeanDefinition.ROLE_APPLICATION); + }); + } + + @Test + void setRoleWhenInfrastructure() { + this.beanDefinition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()) + .contains("setRole(BeanDefinition.ROLE_INFRASTRUCTURE);"); + assertThat(actual.getRole()).isEqualTo(BeanDefinition.ROLE_INFRASTRUCTURE); + }); + } + + @Test + void setRoleWhenSupport() { + this.beanDefinition.setRole(BeanDefinition.ROLE_SUPPORT); + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()) + .contains("setRole(BeanDefinition.ROLE_SUPPORT);"); + assertThat(actual.getRole()).isEqualTo(BeanDefinition.ROLE_SUPPORT); + }); + } + + @Test + void setRoleWhenOther() { + this.beanDefinition.setRole(999); + testCompiledResult( + (actual, compiled) -> assertThat(actual.getRole()).isEqualTo(999)); + } + + @Test + void setInitMethodWhenSingleInitMethod() { + this.beanDefinition.setTargetType(InitDestroyBean.class); + this.beanDefinition.setInitMethodName("i1"); + testCompiledResult((actual, compiled) -> assertThat(actual.getInitMethodNames()) + .containsExactly("i1")); + assertHasMethodInvokeHints("i1"); + } + + @Test + void setInitMethodWhenMultipleInitMethods() { + this.beanDefinition.setTargetType(InitDestroyBean.class); + this.beanDefinition.setInitMethodNames("i1", "i2"); + testCompiledResult((actual, compiled) -> assertThat(actual.getInitMethodNames()) + .containsExactly("i1", "i2")); + assertHasMethodInvokeHints("i1", "i2"); + } + + @Test + void setDestroyMethodWhenDestroyInitMethod() { + this.beanDefinition.setTargetType(InitDestroyBean.class); + this.beanDefinition.setDestroyMethodName("d1"); + testCompiledResult( + (actual, compiled) -> assertThat(actual.getDestroyMethodNames()) + .containsExactly("d1")); + assertHasMethodInvokeHints("d1"); + } + + @Test + void setDestroyMethodWhenMultipleDestroyMethods() { + this.beanDefinition.setTargetType(InitDestroyBean.class); + this.beanDefinition.setDestroyMethodNames("d1", "d2"); + testCompiledResult( + (actual, compiled) -> assertThat(actual.getDestroyMethodNames()) + .containsExactly("d1", "d2")); + assertHasMethodInvokeHints("d1", "d2"); + } + + private void assertHasMethodInvokeHints(String... methodNames) { + assertThat(hints.reflection().getTypeHint(InitDestroyBean.class)) + .satisfies(typeHint -> { + for (String methodName : methodNames) { + assertThat(typeHint.methods()).anySatisfy(methodHint -> { + assertThat(methodHint.getName()).isEqualTo(methodName); + assertThat(methodHint.getModes()) + .containsExactly(ExecutableMode.INVOKE); + }); + } + }); + } + + @Test + void constructorArgumentValuesWhenValues() { + this.beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, + String.class); + this.beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(1, + "test"); + this.beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(2, + 123); + testCompiledResult((actual, compiled) -> { + Map values = actual.getConstructorArgumentValues() + .getIndexedArgumentValues(); + assertThat(values.get(0).getValue()).isEqualTo(String.class); + assertThat(values.get(1).getValue()).isEqualTo("test"); + assertThat(values.get(2).getValue()).isEqualTo(123); + }); + } + + @Test + void propertyValuesWhenValues() { + this.beanDefinition.getPropertyValues().add("test", String.class); + this.beanDefinition.getPropertyValues().add("spring", "framework"); + testCompiledResult((actual, compiled) -> { + assertThat(actual.getPropertyValues().get("test")).isEqualTo(String.class); + assertThat(actual.getPropertyValues().get("spring")).isEqualTo("framework"); + }); + } + + @Test + void propertyValuesWhenContainsBeanReference() { + this.beanDefinition.getPropertyValues().add("myService", + new RuntimeBeanNameReference("test")); + testCompiledResult((actual, compiled) -> { + assertThat(actual.getPropertyValues().contains("myService")).isTrue(); + assertThat(actual.getPropertyValues().get("myService")) + .isInstanceOfSatisfying(RuntimeBeanReference.class, + beanReference -> assertThat(beanReference.getBeanName()) + .isEqualTo("test")); + }); + } + + @Test + void propertyValuesWhenContainsManagedList() { + ManagedList managedList = new ManagedList<>(); + managedList.add(new RuntimeBeanNameReference("test")); + this.beanDefinition.getPropertyValues().add("value", managedList); + testCompiledResult((actual, compiled) -> { + Object value = actual.getPropertyValues().get("value"); + assertThat(value).isInstanceOf(ManagedList.class); + assertThat(((List) value).get(0)).isInstanceOf(BeanReference.class); + }); + } + + @Test + void propertyValuesWhenContainsManagedSet() { + ManagedSet managedSet = new ManagedSet<>(); + managedSet.add(new RuntimeBeanNameReference("test")); + this.beanDefinition.getPropertyValues().add("value", managedSet); + testCompiledResult((actual, compiled) -> { + Object value = actual.getPropertyValues().get("value"); + assertThat(value).isInstanceOf(ManagedSet.class); + assertThat(((Set) value).iterator().next()) + .isInstanceOf(BeanReference.class); + }); + } + + @Test + void propertyValuesWhenContainsManagedMap() { + ManagedMap managedMap = new ManagedMap<>(); + managedMap.put("test", new RuntimeBeanNameReference("test")); + this.beanDefinition.getPropertyValues().add("value", managedMap); + testCompiledResult((actual, compiled) -> { + Object value = actual.getPropertyValues().get("value"); + assertThat(value).isInstanceOf(ManagedMap.class); + assertThat(((Map) value).get("test")).isInstanceOf(BeanReference.class); + }); + } + + @Test + void attributesWhenAllFiltered() { + this.beanDefinition.setAttribute("a", "A"); + this.beanDefinition.setAttribute("b", "B"); + Predicate attributeFilter = attribute -> false; + this.generator = new BeanDefinitionPropertiesCodeGenerator(this.hints, + attributeFilter, this.generatedMethods, (name, value) -> null); + testCompiledResult((actual, compiled) -> { + assertThat(compiled.getSourceFile()).doesNotContain("setAttribute"); + assertThat(actual.getAttribute("a")).isNull(); + assertThat(actual.getAttribute("b")).isNull(); + }); + } + + @Test + void attributesWhenSomeFiltered() { + this.beanDefinition.setAttribute("a", "A"); + this.beanDefinition.setAttribute("b", "B"); + Predicate attributeFilter = attribute -> "a".equals(attribute); + this.generator = new BeanDefinitionPropertiesCodeGenerator(this.hints, + attributeFilter, this.generatedMethods, (name, value) -> null); + testCompiledResult(this.beanDefinition, (actual, compiled) -> { + assertThat(actual.getAttribute("a")).isEqualTo("A"); + assertThat(actual.getAttribute("b")).isNull(); + }); + } + + @Test + void multipleItems() { + this.beanDefinition.setPrimary(true); + this.beanDefinition.setScope("test"); + this.beanDefinition.setRole(BeanDefinition.ROLE_SUPPORT); + testCompiledResult((actual, compiled) -> { + assertThat(actual.isPrimary()).isTrue(); + assertThat(actual.getScope()).isEqualTo("test"); + assertThat(actual.getRole()).isEqualTo(BeanDefinition.ROLE_SUPPORT); + }); + } + + private void testCompiledResult(BiConsumer result) { + testCompiledResult(this.beanDefinition, result); + } + + private void testCompiledResult(RootBeanDefinition beanDefinition, + BiConsumer result) { + testCompiledResult(() -> this.generator.generateCode(beanDefinition), result); + } + + private void testCompiledResult(Supplier codeBlock, + BiConsumer result) { + JavaFile javaFile = createJavaFile(codeBlock); + TestCompiler.forSystem().compile(javaFile::writeTo, compiled -> { + RootBeanDefinition beanDefinition = (RootBeanDefinition) compiled + .getInstance(Supplier.class).get(); + result.accept(beanDefinition, compiled); + }); + } + + private JavaFile createJavaFile(Supplier codeBlock) { + TypeSpec.Builder builder = TypeSpec.classBuilder("BeanSupplier"); + builder.addModifiers(Modifier.PUBLIC); + builder.addSuperinterface( + ParameterizedTypeName.get(Supplier.class, RootBeanDefinition.class)); + builder.addMethod(MethodSpec.methodBuilder("get").addModifiers(Modifier.PUBLIC) + .returns(RootBeanDefinition.class) + .addStatement("$T beanDefinition = new $T()", RootBeanDefinition.class, + RootBeanDefinition.class) + .addStatement("$T beanFactory = new $T()", + DefaultListableBeanFactory.class, + DefaultListableBeanFactory.class) + .addCode(codeBlock.get()).addStatement("return beanDefinition").build()); + this.generatedMethods.doWithMethodSpecs(builder::addMethod); + return JavaFile.builder("com.example", builder.build()).build(); + } + + static class InitDestroyBean { + + void i1() { + } + + void i2() { + } + + void d1() { + } + + void d2() { + } + + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java new file mode 100644 index 00000000000..6acd67d10e0 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java @@ -0,0 +1,483 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.io.InputStream; +import java.io.OutputStream; +import java.time.temporal.ChronoUnit; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.test.generator.compile.Compiled; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.aot.test.generator.file.SourceFile; +import org.springframework.beans.factory.config.BeanReference; +import org.springframework.beans.factory.config.RuntimeBeanNameReference; +import org.springframework.beans.factory.support.ManagedList; +import org.springframework.beans.factory.support.ManagedMap; +import org.springframework.beans.factory.support.ManagedSet; +import org.springframework.core.ResolvableType; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.javapoet.TypeSpec; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link BeanDefinitionPropertyValueCodeGenerator}. + * + * @author Stephane Nicoll + * @author Phillip Webb + * @since 6.0 + * @see BeanDefinitionPropertyValueCodeGeneratorTests + */ +class BeanDefinitionPropertyValueCodeGeneratorTests { + + private GeneratedMethods generatedMethods = new GeneratedMethods(); + + private BeanDefinitionPropertyValueCodeGenerator instance = new BeanDefinitionPropertyValueCodeGenerator( + generatedMethods); + + private void compile(Object value, BiConsumer result) { + CodeBlock code = instance.generateCode(value); + JavaFile javaFile = createJavaFile(code); + TestCompiler.forSystem().compile(SourceFile.of(javaFile::writeTo), + compiled -> result.accept(compiled.getInstance(Supplier.class).get(), + compiled)); + } + + private JavaFile createJavaFile(CodeBlock code) { + TypeSpec.Builder builder = TypeSpec.classBuilder("InstanceSupplier"); + builder.addModifiers(Modifier.PUBLIC); + builder.addSuperinterface( + ParameterizedTypeName.get(Supplier.class, Object.class)); + builder.addMethod(MethodSpec.methodBuilder("get").addModifiers(Modifier.PUBLIC) + .returns(Object.class).addStatement("return $L", code).build()); + generatedMethods.doWithMethodSpecs(builder::addMethod); + return JavaFile.builder("com.example", builder.build()).build(); + } + + @Nested + class NullTests { + + @Test + void generateWhenNull() { + compile(null, (instance, compiled) -> assertThat(instance).isNull()); + } + + } + + @Nested + class PrimitiveTests { + + @Test + void generateWhenBoolean() { + compile(true, (instance, compiled) -> { + assertThat(instance).isEqualTo(Boolean.TRUE); + assertThat(compiled.getSourceFile()).contains("true"); + }); + } + + @Test + void generateWhenByte() { + compile((byte) 2, (instance, compiled) -> { + assertThat(instance).isEqualTo((byte) 2); + assertThat(compiled.getSourceFile()).contains("(byte) 2"); + }); + } + + @Test + void generateWhenShort() { + compile((short) 3, (instance, compiled) -> { + assertThat(instance).isEqualTo((short) 3); + assertThat(compiled.getSourceFile()).contains("(short) 3"); + }); + } + + @Test + void generateWhenInt() { + compile(4, (instance, compiled) -> { + assertThat(instance).isEqualTo(4); + assertThat(compiled.getSourceFile()).contains("return 4;"); + }); + } + + @Test + void generateWhenLong() { + compile(5L, (instance, compiled) -> { + assertThat(instance).isEqualTo(5L); + assertThat(compiled.getSourceFile()).contains("5L"); + }); + } + + @Test + void generateWhenFloat() { + compile(0.1F, (instance, compiled) -> { + assertThat(instance).isEqualTo(0.1F); + assertThat(compiled.getSourceFile()).contains("0.1F"); + }); + } + + @Test + void generateWhenDouble() { + compile(0.2, (instance, compiled) -> { + assertThat(instance).isEqualTo(0.2); + assertThat(compiled.getSourceFile()).contains("(double) 0.2"); + }); + } + + @Test + void generateWhenChar() { + compile('a', (instance, compiled) -> { + assertThat(instance).isEqualTo('a'); + assertThat(compiled.getSourceFile()).contains("'a'"); + }); + } + + @Test + void generateWhenSimpleEscapedCharReturnsEscaped() { + testEscaped('\b', "'\\b'"); + testEscaped('\t', "'\\t'"); + testEscaped('\n', "'\\n'"); + testEscaped('\f', "'\\f'"); + testEscaped('\r', "'\\r'"); + testEscaped('\"', "'\"'"); + testEscaped('\'', "'\\''"); + testEscaped('\\', "'\\\\'"); + } + + @Test + void generatedWhenUnicodeEscapedCharReturnsEscaped() { + testEscaped('\u007f', "'\\u007f'"); + } + + private void testEscaped(char value, String expectedSourceContent) { + compile(value, (instance, compiled) -> { + assertThat(instance).isEqualTo(value); + assertThat(compiled.getSourceFile()).contains(expectedSourceContent); + }); + } + + } + + @Nested + class StringTests { + + @Test + void generateWhenString() { + compile("test\n", (instance, compiled) -> { + assertThat(instance).isEqualTo("test\n"); + assertThat(compiled.getSourceFile()).contains("\n"); + }); + } + + } + + @Nested + class EnumTests { + + @Test + void generateWhenEnum() { + compile(ChronoUnit.DAYS, (instance, compiled) -> { + assertThat(instance).isEqualTo(ChronoUnit.DAYS); + assertThat(compiled.getSourceFile()).contains("ChronoUnit.DAYS"); + }); + } + + @Test + void generateWhenEnumWithClassBody() { + compile(EnumWithClassBody.TWO, (instance, compiled) -> { + assertThat(instance).isEqualTo(EnumWithClassBody.TWO); + assertThat(compiled.getSourceFile()).contains("EnumWithClassBody.TWO"); + }); + } + + } + + @Nested + class ClassTests { + + @Test + void generateWhenClass() { + compile(InputStream.class, (instance, compiled) -> assertThat(instance) + .isEqualTo(InputStream.class)); + } + + @Test + void generateWhenCglibClass() { + compile(ExampleClass$$GeneratedBy.class, (instance, + compiled) -> assertThat(instance).isEqualTo(ExampleClass.class)); + } + + } + + @Nested + class ResolvableTypeTests { + + @Test + void generateWhenSimpleResolvableType() { + ResolvableType resolvableType = ResolvableType.forClass(String.class); + compile(resolvableType, (instance, compiled) -> assertThat(instance) + .isEqualTo(resolvableType)); + } + + @Test + void generateWhenNoneResolvableType() { + ResolvableType resolvableType = ResolvableType.NONE; + compile(resolvableType, (instance, compiled) -> { + assertThat(instance).isEqualTo(resolvableType); + assertThat(compiled.getSourceFile()).contains("ResolvableType.NONE"); + }); + } + + @Test + void generateWhenGenericResolvableType() { + ResolvableType resolvableType = ResolvableType + .forClassWithGenerics(List.class, String.class); + compile(resolvableType, (instance, compiled) -> assertThat(instance) + .isEqualTo(resolvableType)); + } + + @Test + void generateWhenNestedGenericResolvableType() { + ResolvableType stringList = ResolvableType.forClassWithGenerics(List.class, + String.class); + ResolvableType resolvableType = ResolvableType.forClassWithGenerics(Map.class, + ResolvableType.forClass(Integer.class), stringList); + compile(resolvableType, (instance, compiled) -> assertThat(instance) + .isEqualTo(resolvableType)); + } + + } + + @Nested + class ArrayTests { + + @Test + void generateWhenPrimitiveArray() { + byte[] bytes = { 0, 1, 2 }; + compile(bytes, (instance, compiler) -> { + assertThat(instance).isEqualTo(bytes); + assertThat(compiler.getSourceFile()).contains("new byte[]"); + }); + } + + @Test + void generateWhenWrapperArray() { + Byte[] bytes = { 0, 1, 2 }; + compile(bytes, (instance, compiler) -> { + assertThat(instance).isEqualTo(bytes); + assertThat(compiler.getSourceFile()).contains("new Byte[]"); + }); + } + + @Test + void generateWhenClassArray() { + Class[] classes = new Class[] { InputStream.class, OutputStream.class }; + compile(classes, (instance, compiler) -> { + assertThat(instance).isEqualTo(classes); + assertThat(compiler.getSourceFile()).contains("new Class[]"); + }); + } + + } + + @Nested + class ManagedListTests { + + @Test + void generateWhenStringManagedList() { + ManagedList list = new ManagedList<>(); + list.add("a"); + list.add("b"); + list.add("c"); + compile(list, (instance, compiler) -> assertThat(instance).isEqualTo(list) + .isInstanceOf(ManagedList.class)); + } + + @Test + void generateWhenEmptyManagedList() { + ManagedList list = new ManagedList<>(); + compile(list, (instance, compiler) -> assertThat(instance).isEqualTo(list) + .isInstanceOf(ManagedList.class)); + } + + } + + @Nested + class ManagedSetTests { + + @Test + void generateWhenStringManagedSet() { + ManagedSet set = new ManagedSet<>(); + set.add("a"); + set.add("b"); + set.add("c"); + compile(set, (instance, compiler) -> assertThat(instance).isEqualTo(set) + .isInstanceOf(ManagedSet.class)); + } + + @Test + void generateWhenEmptyManagedSet() { + ManagedSet set = new ManagedSet<>(); + compile(set, (instance, compiler) -> assertThat(instance).isEqualTo(set) + .isInstanceOf(ManagedSet.class)); + } + + } + + @Nested + class ManagedMapTests { + + @Test + void generateWhenManagedMap() { + ManagedMap map = new ManagedMap<>(); + map.put("k1", "v1"); + map.put("k2", "v2"); + compile(map, (instance, compiler) -> assertThat(instance).isEqualTo(map) + .isInstanceOf(ManagedMap.class)); + } + + @Test + void generateWhenEmptyManagedMap() { + ManagedMap map = new ManagedMap<>(); + compile(map, (instance, compiler) -> assertThat(instance).isEqualTo(map) + .isInstanceOf(ManagedMap.class)); + } + + } + + @Nested + class ListTests { + + @Test + void generateWhenStringList() { + List list = List.of("a", "b", "c"); + compile(list, (instance, compiler) -> assertThat(instance).isEqualTo(list) + .isNotInstanceOf(ManagedList.class)); + } + + @Test + void generateWhenEmptyList() { + List list = List.of(); + compile(list, (instance, compiler) -> { + assertThat(instance).isEqualTo(list); + assertThat(compiler.getSourceFile()).contains("Collections.emptyList();"); + }); + } + + } + + @Nested + class SetTests { + + @Test + void generateWhenStringSet() { + Set set = Set.of("a", "b", "c"); + compile(set, (instance, compiler) -> assertThat(instance).isEqualTo(set) + .isNotInstanceOf(ManagedSet.class)); + } + + @Test + void generateWhenEmptySet() { + Set set = Set.of(); + compile(set, (instance, compiler) -> { + assertThat(instance).isEqualTo(set); + assertThat(compiler.getSourceFile()).contains("Collections.emptySet();"); + }); + } + + @Test + void generateWhenLinkedHashSet() { + Set set = new LinkedHashSet<>(List.of("a", "b", "c")); + compile(set, (instance, compiler) -> { + assertThat(instance).isEqualTo(set).isInstanceOf(LinkedHashSet.class); + assertThat(compiler.getSourceFile()) + .contains("new LinkedHashSet(List.of("); + }); + } + + } + + @Nested + class MapTests { + + @Test + void generateWhenSmallMap() { + Map map = Map.of("k1", "v1", "k2", "v2"); + compile(map, (instance, compiler) -> { + assertThat(instance).isEqualTo(map); + assertThat(compiler.getSourceFile()).contains("Map.of("); + }); + } + + @Test + void generateWhenMapWithOverTenElements() { + Map map = new HashMap<>(); + for (int i = 1; i <= 11; i++) { + map.put("k" + i, "v" + i); + } + compile(map, (instance, compiler) -> { + assertThat(instance).isEqualTo(map); + assertThat(compiler.getSourceFile()).contains("Map.ofEntries("); + }); + } + + @Test + void generateWhenLinkedHashMap() { + Map map = new LinkedHashMap<>(); + map.put("a", "A"); + map.put("b", "B"); + map.put("c", "C"); + compile(map, (instance, compiler) -> { + assertThat(instance).isEqualTo(map).isInstanceOf(LinkedHashMap.class); + assertThat(compiler.getSourceFile()).contains("getMap()"); + }); + } + + } + + @Nested + class BeanReferenceTests { + + @Test + void generatedWhenBeanReference() { + BeanReference beanReference = new RuntimeBeanNameReference("test"); + compile(beanReference, + (instance, + compiler) -> assertThat( + ((BeanReference) instance).getBeanName()) + .isEqualTo(beanReference.getBeanName())); + } + + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java new file mode 100644 index 00000000000..1cca7e1bdf7 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.DefaultGenerationContext; +import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.InMemoryGeneratedFiles; +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.test.generator.compile.Compiled; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.TestBean; +import org.springframework.core.mock.MockSpringFactoriesLoader; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.javapoet.TypeSpec; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link BeanRegistrationsAotContribution}. + * + * @author Phillip Webb + */ +class BeanRegistrationsAotContributionTests { + + private InMemoryGeneratedFiles generatedFiles; + + private DefaultGenerationContext generationContext; + + private DefaultListableBeanFactory beanFactory; + + private MockSpringFactoriesLoader springFactoriesLoader; + + private BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; + + private final MockBeanFactoryInitializationCode beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); + + @BeforeEach + void setup() { + this.generatedFiles = new InMemoryGeneratedFiles(); + this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.beanFactory = new DefaultListableBeanFactory(); + this.springFactoriesLoader = new MockSpringFactoriesLoader(); + this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( + new AotFactoriesLoader(this.beanFactory, this.springFactoriesLoader)); + } + + @Test + void applyToAppliesContribution() { + Map registrations = new LinkedHashMap<>(); + RegisteredBean registeredBean = registerBean( + new RootBeanDefinition(TestBean.class)); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList(), Collections.emptyList()); + registrations.put("testBean", generator); + BeanRegistrationsAotContribution contribution = new BeanRegistrationsAotContribution( + registrations); + contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); + testCompiledResult((consumer, compiled) -> { + DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); + consumer.accept(freshBeanFactory); + assertThat(freshBeanFactory.getBean(TestBean.class)).isNotNull(); + }); + } + + @Test + void applyToCallsRegistrationsWithBeanRegistrationsCode() { + List beanRegistrationsCodes = new ArrayList<>(); + Map registrations = new LinkedHashMap<>(); + RegisteredBean registeredBean = registerBean( + new RootBeanDefinition(TestBean.class)); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList(), Collections.emptyList()) { + + @Override + MethodReference generateBeanDefinitionMethod( + GenerationContext generationContext, + BeanRegistrationsCode beanRegistrationsCode) { + beanRegistrationsCodes.add(beanRegistrationsCode); + return super.generateBeanDefinitionMethod(generationContext, + beanRegistrationsCode); + } + + }; + registrations.put("testBean", generator); + BeanRegistrationsAotContribution contribution = new BeanRegistrationsAotContribution( + registrations); + contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); + assertThat(beanRegistrationsCodes).hasSize(1); + BeanRegistrationsCode actual = beanRegistrationsCodes.get(0); + assertThat(actual.getMethodGenerator()).isNotNull(); + } + + private RegisteredBean registerBean(RootBeanDefinition rootBeanDefinition) { + String beanName = "testBean"; + this.beanFactory.registerBeanDefinition(beanName, rootBeanDefinition); + return RegisteredBean.of(this.beanFactory, beanName); + } + + @SuppressWarnings({ "unchecked", "cast" }) + private void testCompiledResult( + BiConsumer, Compiled> result) { + this.generationContext.writeGeneratedContent(); + JavaFile javaFile = createJavaFile(); + TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo, + compiled -> result.accept(compiled.getInstance(Consumer.class), + compiled)); + } + + private JavaFile createJavaFile() { + MethodReference initializer = this.beanFactoryInitializationCode.initializers + .get(0); + TypeSpec.Builder builder = TypeSpec.classBuilder("BeanFactoryConsumer"); + builder.addModifiers(Modifier.PUBLIC); + builder.addSuperinterface(ParameterizedTypeName.get(Consumer.class, + DefaultListableBeanFactory.class)); + builder.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) + .addParameter(DefaultListableBeanFactory.class, "beanFactory") + .addStatement(initializer.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .build()); + return JavaFile.builder("__", builder.build()).build(); + } + + class MockBeanFactoryInitializationCode implements BeanFactoryInitializationCode { + + private final GeneratedMethods generatedMethods = new GeneratedMethods(); + + private final List initializers = new ArrayList<>(); + + @Override + public MethodGenerator getMethodGenerator() { + return this.generatedMethods; + } + + @Override + public void addInitializer(MethodReference methodReference) { + this.initializers.add(methodReference); + } + + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotProcessorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotProcessorTests.java new file mode 100644 index 00000000000..471e006a885 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotProcessorTests.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.AnnotatedBean; +import org.springframework.beans.testfixture.beans.TestBean; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link BeanRegistrationsAotProcessor}. + * + * @author Phillip Webb + */ +class BeanRegistrationsAotProcessorTests { + + @Test + void processAheadOfTimeReturnsBeanRegistrationsAotContributionWithRegistrations() { + BeanRegistrationsAotProcessor processor = new BeanRegistrationsAotProcessor(); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("b1", new RootBeanDefinition(TestBean.class)); + beanFactory.registerBeanDefinition("b2", + new RootBeanDefinition(AnnotatedBean.class)); + BeanRegistrationsAotContribution contribution = processor + .processAheadOfTime(beanFactory); + assertThat(contribution).extracting("registrations") + .asInstanceOf(InstanceOfAssertFactories.MAP).containsKeys("b1", "b2"); + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/ConstructorOrFactoryMethodResolverTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/ConstructorOrFactoryMethodResolverTests.java new file mode 100644 index 00000000000..0db5b543ef0 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/ConstructorOrFactoryMethodResolverTests.java @@ -0,0 +1,499 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Executable; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.Executor; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolder; +import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolderFactoryBean; +import org.springframework.beans.testfixture.beans.factory.generator.factory.SampleFactory; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.MergedAnnotations.SearchStrategy; +import org.springframework.lang.Nullable; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +/** + * Tests for {@link ConstructorOrFactoryMethodResolver}. + * + * @author Stephane Nicoll + * @author Phillip Webb + */ +class ConstructorOrFactoryMethodResolverTests { + + @Test + void detectBeanInstanceExecutableWithBeanClassAndFactoryMethodName() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("testBean", "test"); + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(SampleFactory.class).setFactoryMethod("create") + .addConstructorArgReference("testBean").getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + ReflectionUtils.findMethod(SampleFactory.class, "create", String.class)); + } + + @Test + void detectBeanInstanceExecutableWithBeanClassNameAndFactoryMethodName() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("testBean", "test"); + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(SampleFactory.class.getName()) + .setFactoryMethod("create").addConstructorArgReference("testBean") + .getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + ReflectionUtils.findMethod(SampleFactory.class, "create", String.class)); + } + + @Test + void beanDefinitionWithFactoryMethodNameAndAssignableConstructorArg() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("testNumber", 1L); + beanFactory.registerSingleton("testBean", "test"); + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(SampleFactory.class).setFactoryMethod("create") + .addConstructorArgReference("testNumber") + .addConstructorArgReference("testBean").getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo(ReflectionUtils + .findMethod(SampleFactory.class, "create", Number.class, String.class)); + } + + @Test + void beanDefinitionWithFactoryMethodNameAndMatchingMethodNamesThatShouldBeIgnored() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(DummySampleFactory.class).setFactoryMethod("of") + .addConstructorArgValue(42).getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo(ReflectionUtils + .findMethod(DummySampleFactory.class, "of", Integer.class)); + } + + @Test + void detectBeanInstanceExecutableWithBeanClassAndFactoryMethodNameIgnoreTargetType() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("testBean", "test"); + RootBeanDefinition beanDefinition = (RootBeanDefinition) BeanDefinitionBuilder + .rootBeanDefinition(SampleFactory.class).setFactoryMethod("create") + .addConstructorArgReference("testBean").getBeanDefinition(); + beanDefinition.setTargetType(String.class); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + ReflectionUtils.findMethod(SampleFactory.class, "create", String.class)); + } + + @Test + void beanDefinitionWithConstructorArgsForMultipleConstructors() throws Exception { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("testNumber", 1L); + beanFactory.registerSingleton("testBean", "test"); + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(SampleBeanWithConstructors.class) + .addConstructorArgReference("testNumber") + .addConstructorArgReference("testBean").getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo(SampleBeanWithConstructors.class + .getDeclaredConstructor(Number.class, String.class)); + } + + @Test + void genericBeanDefinitionWithConstructorArgsForMultipleConstructors() + throws Exception { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("testNumber", 1L); + beanFactory.registerSingleton("testBean", "test"); + BeanDefinition beanDefinition = BeanDefinitionBuilder + .genericBeanDefinition(SampleBeanWithConstructors.class) + .addConstructorArgReference("testNumber") + .addConstructorArgReference("testBean").getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo(SampleBeanWithConstructors.class + .getDeclaredConstructor(Number.class, String.class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingValue() + throws NoSuchMethodException { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorSample.class) + .addConstructorArgValue(42).getBeanDefinition(); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + MultiConstructorSample.class.getDeclaredConstructor(Integer.class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingArrayValue() + throws NoSuchMethodException { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorArraySample.class) + .addConstructorArgValue(42).getBeanDefinition(); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNotNull().isEqualTo(MultiConstructorArraySample.class + .getDeclaredConstructor(Integer[].class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingListValue() + throws NoSuchMethodException { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorListSample.class) + .addConstructorArgValue(42).getBeanDefinition(); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + MultiConstructorListSample.class.getDeclaredConstructor(List.class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingValueAsInnerBean() + throws NoSuchMethodException { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorSample.class) + .addConstructorArgValue( + BeanDefinitionBuilder.rootBeanDefinition(Integer.class, "valueOf") + .addConstructorArgValue("42").getBeanDefinition()) + .getBeanDefinition(); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + MultiConstructorSample.class.getDeclaredConstructor(Integer.class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingValueAsInnerBeanFactory() + throws NoSuchMethodException { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorSample.class) + .addConstructorArgValue(BeanDefinitionBuilder + .rootBeanDefinition(IntegerFactoryBean.class).getBeanDefinition()) + .getBeanDefinition(); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + MultiConstructorSample.class.getDeclaredConstructor(Integer.class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndNonMatchingValue() { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorSample.class) + .addConstructorArgValue(Locale.ENGLISH).getBeanDefinition(); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNull(); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndNonMatchingValueAsInnerBean() { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorSample.class) + .addConstructorArgValue(BeanDefinitionBuilder + .rootBeanDefinition(Locale.class, "getDefault") + .getBeanDefinition()) + .getBeanDefinition(); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNull(); + } + + @Test + void detectBeanInstanceExecutableWithFactoryBeanSetInBeanClass() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(); + beanDefinition.setTargetType( + ResolvableType.forClassWithGenerics(NumberHolder.class, Integer.class)); + beanDefinition.setBeanClass(NumberHolderFactoryBean.class); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull() + .isEqualTo(NumberHolderFactoryBean.class.getDeclaredConstructors()[0]); + } + + @Test + void detectBeanInstanceExecutableWithFactoryBeanSetInBeanClassAndNoResolvableType() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(); + beanDefinition.setBeanClass(NumberHolderFactoryBean.class); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull() + .isEqualTo(NumberHolderFactoryBean.class.getDeclaredConstructors()[0]); + } + + @Test + void detectBeanInstanceExecutableWithFactoryBeanSetInBeanClassThatDoesNotMatchTargetType() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(); + beanDefinition.setTargetType( + ResolvableType.forClassWithGenerics(NumberHolder.class, String.class)); + beanDefinition.setBeanClass(NumberHolderFactoryBean.class); + assertThatIllegalStateException() + .isThrownBy(() -> resolve(beanFactory, beanDefinition)) + .withMessageContaining("Incompatible target type") + .withMessageContaining(NumberHolder.class.getName()) + .withMessageContaining(NumberHolderFactoryBean.class.getName()); + } + + @Test + void beanDefinitionWithClassArrayConstructorArgAndStringArrayValueType() + throws NoSuchMethodException { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(ConstructorClassArraySample.class.getName()) + .addConstructorArgValue(new String[] { "test1, test2" }) + .getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + ConstructorClassArraySample.class.getDeclaredConstructor(Class[].class)); + } + + @Test + void beanDefinitionWithClassArrayConstructorArgAndStringValueType() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(ConstructorClassArraySample.class.getName()) + .addConstructorArgValue("test1").getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + ConstructorClassArraySample.class.getDeclaredConstructors()[0]); + } + + @Test + void beanDefinitionWithClassArrayConstructorArgAndAnotherMatchingConstructor() + throws NoSuchMethodException { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorClassArraySample.class.getName()) + .addConstructorArgValue(new String[] { "test1, test2" }) + .getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull() + .isEqualTo(MultiConstructorClassArraySample.class + .getDeclaredConstructor(String[].class)); + } + + @Test + void beanDefinitionWithClassArrayFactoryMethodArgAndStringArrayValueType() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(ClassArrayFactoryMethodSample.class.getName()) + .setFactoryMethod("of") + .addConstructorArgValue(new String[] { "test1, test2" }) + .getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo(ReflectionUtils + .findMethod(ClassArrayFactoryMethodSample.class, "of", Class[].class)); + } + + @Test + void beanDefinitionWithClassArrayFactoryMethodArgAndAnotherMatchingConstructor() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition( + ClassArrayFactoryMethodSampleWithAnotherFactoryMethod.class.getName()) + .setFactoryMethod("of").addConstructorArgValue("test1") + .getBeanDefinition(); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull() + .isEqualTo(ReflectionUtils.findMethod( + ClassArrayFactoryMethodSampleWithAnotherFactoryMethod.class, "of", + String[].class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndPrimitiveConversion() + throws NoSuchMethodException { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(ConstructorPrimitiveFallback.class) + .addConstructorArgValue("true").getBeanDefinition(); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isEqualTo( + ConstructorPrimitiveFallback.class.getDeclaredConstructor(boolean.class)); + } + + @Test + void beanDefinitionWithFactoryWithOverloadedClassMethodsOnInterface() { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(FactoryWithOverloadedClassMethodsOnInterface.class) + .setFactoryMethod("byAnnotation").addConstructorArgValue(Nullable.class) + .getBeanDefinition(); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isEqualTo(ReflectionUtils.findMethod( + FactoryWithOverloadedClassMethodsOnInterface.class, "byAnnotation", + Class.class)); + } + + private Executable resolve(DefaultListableBeanFactory beanFactory, + BeanDefinition beanDefinition) { + return new ConstructorOrFactoryMethodResolver(beanFactory) + .resolve(beanDefinition); + } + + static class IntegerFactoryBean implements FactoryBean { + + @Override + public Integer getObject() { + return 42; + } + + @Override + public Class getObjectType() { + return Integer.class; + } + } + + @SuppressWarnings("unused") + static class MultiConstructorSample { + + MultiConstructorSample(String name) { + } + + MultiConstructorSample(Integer value) { + } + + } + + @SuppressWarnings("unused") + static class MultiConstructorArraySample { + + public MultiConstructorArraySample(String... names) { + } + + public MultiConstructorArraySample(Integer... values) { + } + } + + @SuppressWarnings("unused") + static class MultiConstructorListSample { + + public MultiConstructorListSample(String name) { + } + + public MultiConstructorListSample(List values) { + } + + } + + interface DummyInterface { + + static String of(Object o) { + return o.toString(); + } + } + + @SuppressWarnings("unused") + static class DummySampleFactory implements DummyInterface { + + static String of(Integer value) { + return value.toString(); + } + + private String of(String ignored) { + return ignored; + } + } + + @SuppressWarnings("unused") + static class ConstructorClassArraySample { + + ConstructorClassArraySample(Class... classArrayArg) { + } + + ConstructorClassArraySample(Executor somethingElse) { + } + } + + @SuppressWarnings("unused") + static class MultiConstructorClassArraySample { + + MultiConstructorClassArraySample(Class... classArrayArg) { + } + + MultiConstructorClassArraySample(String... stringArrayArg) { + } + } + + @SuppressWarnings("unused") + static class ClassArrayFactoryMethodSample { + + static String of(Class[] classArrayArg) { + return "test"; + } + + } + + @SuppressWarnings("unused") + static class ClassArrayFactoryMethodSampleWithAnotherFactoryMethod { + + static String of(Class[] classArrayArg) { + return "test"; + } + + static String of(String[] classArrayArg) { + return "test"; + } + + } + + @SuppressWarnings("unnused") + static class ConstructorPrimitiveFallback { + + public ConstructorPrimitiveFallback(boolean useDefaultExecutor) { + } + + public ConstructorPrimitiveFallback(Executor executor) { + } + + } + + static class SampleBeanWithConstructors { + + public SampleBeanWithConstructors() { + } + + public SampleBeanWithConstructors(String name) { + } + + public SampleBeanWithConstructors(Number number, String name) { + } + + } + + interface FactoryWithOverloadedClassMethodsOnInterface { + + static FactoryWithOverloadedClassMethodsOnInterface byAnnotation( + Class annotationType) { + return byAnnotation(annotationType, SearchStrategy.INHERITED_ANNOTATIONS); + } + + static FactoryWithOverloadedClassMethodsOnInterface byAnnotation( + Class annotationType, + SearchStrategy searchStrategy) { + return null; + } + + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/EnumWithClassBody.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/EnumWithClassBody.java new file mode 100644 index 00000000000..f7d702c36b3 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/EnumWithClassBody.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +/** + * Test enum that include a class body. + * + * @author Phillip Webb + */ +public enum EnumWithClassBody { + + /** + * No class body. + */ + ONE, + + /** + * With class body. + */ + TWO { + + @Override + public String toString() { + return "2"; + } + + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/ExampleClass$$GeneratedBy.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/ExampleClass$$GeneratedBy.java new file mode 100644 index 00000000000..8b40ef58def --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/ExampleClass$$GeneratedBy.java @@ -0,0 +1,26 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +/** + * Fake CGLIB generated class. + * + * @author Phillip Webb + */ +class ExampleClass$$GeneratedBy extends ExampleClass { + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/ExampleClass.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/ExampleClass.java new file mode 100644 index 00000000000..c549b9befab --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/ExampleClass.java @@ -0,0 +1,26 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +/** + * Public example class used for test. + * + * @author Phillip Webb + */ +public class ExampleClass { + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java new file mode 100644 index 00000000000..d4de4edc3b6 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java @@ -0,0 +1,347 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Executable; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import javax.lang.model.element.Modifier; + +import org.assertj.core.api.ThrowingConsumer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.DefaultGenerationContext; +import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.InMemoryGeneratedFiles; +import org.springframework.aot.hint.ExecutableHint; +import org.springframework.aot.hint.ExecutableMode; +import org.springframework.aot.hint.ReflectionHints; +import org.springframework.aot.hint.TypeHint; +import org.springframework.aot.test.generator.compile.Compiled; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.InstanceSupplier; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.TestBean; +import org.springframework.beans.testfixture.beans.TestBeanWithPrivateConstructor; +import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration; +import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration.EnvironmentAwareComponent; +import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration.NoDependencyComponent; +import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration; +import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolder; +import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolderFactoryBean; +import org.springframework.beans.testfixture.beans.factory.generator.factory.SampleFactory; +import org.springframework.beans.testfixture.beans.factory.generator.injection.InjectionComponent; +import org.springframework.core.env.StandardEnvironment; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.javapoet.TypeSpec; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link InstanceSupplierCodeGenerator}. + * + * @author Phillip Webb + * @author Stephane Nicoll + */ +class InstanceSupplierCodeGeneratorTests { + + private InMemoryGeneratedFiles generatedFiles; + + private DefaultGenerationContext generationContext; + + private boolean allowDirectSupplierShortcut = false; + + private ClassName className = ClassName.get("__", "InstanceSupplierSupplier"); + + + @BeforeEach + void setup() { + this.generatedFiles = new InMemoryGeneratedFiles(); + this.generationContext = new DefaultGenerationContext(this.generatedFiles); + } + + + @Test + void generateWhenHasDefaultConstructor() { + BeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + TestBean bean = getBean(beanFactory, beanDefinition, instanceSupplier); + assertThat(bean).isInstanceOf(TestBean.class); + assertThat(compiled.getSourceFile()) + .contains("InstanceSupplier.using(TestBean::new)"); + }); + assertThat(getReflectionHints().getTypeHint(TestBean.class)) + .satisfies(hasConstructorWithMode(ExecutableMode.INTROSPECT)); + } + + @Test + void generateWhenHasConstructorWithParameter() { + BeanDefinition beanDefinition = new RootBeanDefinition(InjectionComponent.class); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("injected", "injected"); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + InjectionComponent bean = getBean(beanFactory, beanDefinition, + instanceSupplier); + assertThat(bean).isInstanceOf(InjectionComponent.class).extracting("bean") + .isEqualTo("injected"); + }); + assertThat(getReflectionHints().getTypeHint(InjectionComponent.class)) + .satisfies(hasConstructorWithMode(ExecutableMode.INTROSPECT)); + } + + @Test + void generateWhenHasConstructorWithInnerClassAndDefaultConstructor() { + RootBeanDefinition beanDefinition = new RootBeanDefinition( + NoDependencyComponent.class); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("configuration", new InnerComponentConfiguration()); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + NoDependencyComponent bean = getBean(beanFactory, beanDefinition, + instanceSupplier); + assertThat(bean).isInstanceOf(NoDependencyComponent.class); + assertThat(compiled.getSourceFile()).contains( + "getBeanFactory().getBean(InnerComponentConfiguration.class).new NoDependencyComponent()"); + }); + assertThat(getReflectionHints().getTypeHint(NoDependencyComponent.class)) + .satisfies(hasConstructorWithMode(ExecutableMode.INTROSPECT)); + } + + @Test + void generateWhenHasConstructorWithInnerClassAndParameter() { + BeanDefinition beanDefinition = new RootBeanDefinition( + EnvironmentAwareComponent.class); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("configuration", new InnerComponentConfiguration()); + beanFactory.registerSingleton("environment", new StandardEnvironment()); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + EnvironmentAwareComponent bean = getBean(beanFactory, beanDefinition, + instanceSupplier); + assertThat(bean).isInstanceOf(EnvironmentAwareComponent.class); + assertThat(compiled.getSourceFile()).contains( + "getBeanFactory().getBean(InnerComponentConfiguration.class).new EnvironmentAwareComponent("); + }); + assertThat(getReflectionHints().getTypeHint(EnvironmentAwareComponent.class)) + .satisfies(hasConstructorWithMode(ExecutableMode.INTROSPECT)); + } + + @Test + void generateWhenHasConstructorWithGeneric() { + BeanDefinition beanDefinition = new RootBeanDefinition( + NumberHolderFactoryBean.class); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("number", 123); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + NumberHolder bean = getBean(beanFactory, beanDefinition, instanceSupplier); + assertThat(bean).isInstanceOf(NumberHolder.class); + assertThat(bean).extracting("number").isNull(); // No property + // actually set + assertThat(compiled.getSourceFile()).contains("NumberHolderFactoryBean::new"); + }); + assertThat(getReflectionHints().getTypeHint(NumberHolderFactoryBean.class)) + .satisfies(hasConstructorWithMode(ExecutableMode.INTROSPECT)); + } + + @Test + void generateWhenHasPrivateConstructor() { + BeanDefinition beanDefinition = new RootBeanDefinition( + TestBeanWithPrivateConstructor.class); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + TestBeanWithPrivateConstructor bean = getBean(beanFactory, beanDefinition, + instanceSupplier); + assertThat(bean).isInstanceOf(TestBeanWithPrivateConstructor.class); + assertThat(compiled.getSourceFile()) + .contains("resolveAndInstantiate(registeredBean)"); + }); + assertThat(getReflectionHints().getTypeHint(TestBeanWithPrivateConstructor.class)) + .satisfies(hasConstructorWithMode(ExecutableMode.INVOKE)); + } + + @Test + void generateWhenHasFactoryMethodWithNoArg() { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(String.class) + .setFactoryMethodOnBean("stringBean", "config").getBeanDefinition(); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder + .genericBeanDefinition(SimpleConfiguration.class).getBeanDefinition()); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + String bean = getBean(beanFactory, beanDefinition, instanceSupplier); + assertThat(bean).isInstanceOf(String.class); + assertThat(bean).isEqualTo("Hello"); + assertThat(compiled.getSourceFile()).contains( + "getBeanFactory().getBean(SimpleConfiguration.class).stringBean()"); + }); + assertThat(getReflectionHints().getTypeHint(SimpleConfiguration.class)) + .satisfies(hasMethodWithMode(ExecutableMode.INTROSPECT)); + } + + @Test + void generateWhenHasPrivateStaticFactoryMethodWithNoArg() { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(String.class) + .setFactoryMethodOnBean("privateStaticStringBean", "config") + .getBeanDefinition(); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder + .genericBeanDefinition(SimpleConfiguration.class).getBeanDefinition()); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + String bean = getBean(beanFactory, beanDefinition, instanceSupplier); + assertThat(bean).isInstanceOf(String.class); + assertThat(bean).isEqualTo("Hello"); + assertThat(compiled.getSourceFile()) + .contains("resolveAndInstantiate(registeredBean)"); + }); + assertThat(getReflectionHints().getTypeHint(SimpleConfiguration.class)) + .satisfies(hasMethodWithMode(ExecutableMode.INVOKE)); + } + + @Test + void generateWhenHasStaticFactoryMethodWithNoArg() { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(Integer.class) + .setFactoryMethodOnBean("integerBean", "config").getBeanDefinition(); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder + .genericBeanDefinition(SimpleConfiguration.class).getBeanDefinition()); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + Integer bean = getBean(beanFactory, beanDefinition, instanceSupplier); + assertThat(bean).isInstanceOf(Integer.class); + assertThat(bean).isEqualTo(42); + assertThat(compiled.getSourceFile()) + .contains("SimpleConfiguration::integerBean"); + }); + assertThat(getReflectionHints().getTypeHint(SimpleConfiguration.class)) + .satisfies(hasMethodWithMode(ExecutableMode.INTROSPECT)); + } + + @Test + void generateWhenHasStaticFactoryMethodWithArg() { + RootBeanDefinition beanDefinition = (RootBeanDefinition) BeanDefinitionBuilder + .rootBeanDefinition(String.class) + .setFactoryMethodOnBean("create", "config").getBeanDefinition(); + beanDefinition.setResolvedFactoryMethod(ReflectionUtils + .findMethod(SampleFactory.class, "create", Number.class, String.class)); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder + .genericBeanDefinition(SampleFactory.class).getBeanDefinition()); + beanFactory.registerSingleton("number", 42); + beanFactory.registerSingleton("string", "test"); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + String bean = getBean(beanFactory, beanDefinition, instanceSupplier); + assertThat(bean).isInstanceOf(String.class); + assertThat(bean).isEqualTo("42test"); + assertThat(compiled.getSourceFile()).contains("SampleFactory.create("); + }); + assertThat(getReflectionHints().getTypeHint(SampleFactory.class)) + .satisfies(hasMethodWithMode(ExecutableMode.INTROSPECT)); + } + + @Test + void generateWhenHasStaticFactoryMethodCheckedException() { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(Integer.class) + .setFactoryMethodOnBean("throwingIntegerBean", "config") + .getBeanDefinition(); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder + .genericBeanDefinition(SimpleConfiguration.class).getBeanDefinition()); + testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + Integer bean = getBean(beanFactory, beanDefinition, instanceSupplier); + assertThat(bean).isInstanceOf(Integer.class); + assertThat(bean).isEqualTo(42); + assertThat(compiled.getSourceFile()).contains(") throws Exception {"); + }); + assertThat(getReflectionHints().getTypeHint(SimpleConfiguration.class)) + .satisfies(hasMethodWithMode(ExecutableMode.INTROSPECT)); + } + + private ReflectionHints getReflectionHints() { + return this.generationContext.getRuntimeHints().reflection(); + } + + private ThrowingConsumer hasConstructorWithMode(ExecutableMode mode) { + return hint -> assertThat(hint.constructors()).anySatisfy(hasMode(mode)); + } + + private ThrowingConsumer hasMethodWithMode(ExecutableMode mode) { + return hint -> assertThat(hint.methods()).anySatisfy(hasMode(mode)); + } + + private ThrowingConsumer hasMode(ExecutableMode mode) { + return hint -> assertThat(hint.getModes()).containsExactly(mode); + } + + @SuppressWarnings("unchecked") + private T getBean(DefaultListableBeanFactory beanFactory, + BeanDefinition beanDefinition, InstanceSupplier instanceSupplier) { + ((RootBeanDefinition) beanDefinition).setInstanceSupplier(instanceSupplier); + beanFactory.registerBeanDefinition("testBean", beanDefinition); + return (T) beanFactory.getBean("testBean"); + } + + @SuppressWarnings("unchecked") + private void testCompiledResult(DefaultListableBeanFactory beanFactory, + BeanDefinition beanDefinition, + BiConsumer, Compiled> result) { + this.generationContext.writeGeneratedContent(); + DefaultListableBeanFactory registrationBeanFactory = new DefaultListableBeanFactory( + beanFactory); + registrationBeanFactory.registerBeanDefinition("testBean", beanDefinition); + RegisteredBean registeredBean = RegisteredBean.of(registrationBeanFactory, + "testBean"); + GeneratedMethods generatedMethods = new GeneratedMethods(); + InstanceSupplierCodeGenerator generator = new InstanceSupplierCodeGenerator( + this.generationContext, this.className, generatedMethods, + this.allowDirectSupplierShortcut); + Executable constructorOrFactoryMethod = ConstructorOrFactoryMethodResolver + .resolve(registeredBean); + CodeBlock generatedCode = generator.generateCode(registeredBean, + constructorOrFactoryMethod); + JavaFile javaFile = createJavaFile(generatedCode, generatedMethods); + TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo, + compiled -> result.accept( + (InstanceSupplier) compiled.getInstance(Supplier.class).get(), + compiled)); + } + + private JavaFile createJavaFile(CodeBlock generatedCode, + GeneratedMethods generatedMethods) { + TypeSpec.Builder builder = TypeSpec.classBuilder("InstanceSupplierSupplier"); + builder.addModifiers(Modifier.PUBLIC); + builder.addSuperinterface( + ParameterizedTypeName.get(Supplier.class, InstanceSupplier.class)); + builder.addMethod(MethodSpec.methodBuilder("get").addModifiers(Modifier.PUBLIC) + .returns(InstanceSupplier.class).addStatement("return $L", generatedCode) + .build()); + generatedMethods.doWithMethodSpecs(builder::addMethod); + return JavaFile.builder("__", builder.build()).build(); + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/MockBeanRegistrationsCode.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/MockBeanRegistrationsCode.java new file mode 100644 index 00000000000..0a358fc1155 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/MockBeanRegistrationsCode.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.javapoet.ClassName; + +/** + * Mock {@link BeanRegistrationsCode} implementation. + * + * @author Phillip Webb + */ +class MockBeanRegistrationsCode implements BeanRegistrationsCode { + + private final ClassName className; + + private final GeneratedMethods generatedMethods = new GeneratedMethods(); + + + MockBeanRegistrationsCode(ClassName className) { + this.className = className; + } + + + @Override + public ClassName getClassName() { + return this.className; + } + + @Override + public MethodGenerator getMethodGenerator() { + return this.generatedMethods; + } + + GeneratedMethods getGeneratedMethods() { + return this.generatedMethods; + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/PackagePrivateTestBean.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/PackagePrivateTestBean.java new file mode 100644 index 00000000000..52b6e88ebce --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/PackagePrivateTestBean.java @@ -0,0 +1,26 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +/** + * Package-private test bean. + * + * @author Phillip Webb + */ +class PackagePrivateTestBean { + +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/factory/aot/TestBeanRegistrationsAotProcessor.java b/spring-beans/src/testFixtures/java/org/springframework/beans/factory/aot/TestBeanRegistrationsAotProcessor.java new file mode 100644 index 00000000000..079799b229c --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/factory/aot/TestBeanRegistrationsAotProcessor.java @@ -0,0 +1,26 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +/** + * Public variant of {@link BeanRegistrationAotProcessor} for use in tests. + * + * @author Phillip Webb + */ +public class TestBeanRegistrationsAotProcessor extends BeanRegistrationsAotProcessor { + +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateConstructor.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateConstructor.java new file mode 100644 index 00000000000..b3032107ee5 --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateConstructor.java @@ -0,0 +1,24 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans; + +public class TestBeanWithPackagePrivateConstructor { + + TestBeanWithPackagePrivateConstructor() { + } + +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPrivateConstructor.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPrivateConstructor.java new file mode 100644 index 00000000000..cf848d6c2ed --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPrivateConstructor.java @@ -0,0 +1,24 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans; + +public class TestBeanWithPrivateConstructor { + + private TestBeanWithPrivateConstructor() { + } + +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPrivateMethod.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPrivateMethod.java new file mode 100644 index 00000000000..222b4314886 --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPrivateMethod.java @@ -0,0 +1,28 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans; + +@SuppressWarnings("unused") +public class TestBeanWithPrivateMethod { + + private int age; + + private void setAge(int age) { + this.age = age; + } + +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPublicField.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPublicField.java new file mode 100644 index 00000000000..7ebf18974bb --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPublicField.java @@ -0,0 +1,23 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans; + +public class TestBeanWithPublicField { + + public int age; + +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/generator/SimpleConfiguration.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/generator/SimpleConfiguration.java index 0f5ba99ad05..89d459297bf 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/generator/SimpleConfiguration.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/generator/SimpleConfiguration.java @@ -16,6 +16,8 @@ package org.springframework.beans.testfixture.beans.factory.generator; +import java.io.IOException; + public class SimpleConfiguration { public SimpleConfiguration() { @@ -25,7 +27,20 @@ public class SimpleConfiguration { return "Hello"; } - public Integer integerBean() { + @SuppressWarnings("unused") + private static String privateStaticStringBean() { + return "Hello"; + } + + static String packageStaticStringBean() { + return "Hello"; + } + + public static Integer integerBean() { + return 42; + } + + public Integer throwingIntegerBean() throws IOException { return 42; }