From 5bc701d4fe32b45481bcf63f093759df75df7f57 Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Fri, 4 Mar 2022 14:11:02 +0100 Subject: [PATCH] Introduce BeanFactoryContribution This commit introduces an infrastructure to contribute generated code ahead of time to initialize a BeanFactory. Code and hints can be contributed to a BeanFactorInitialization, with the ability to write to other packages if necessary. An implementation of that new interface that registers a BeanDefinition is also included in this commit. It delegates to a BeanInstantiationGenerator for geenerating the instance supplier that creates the bean instance. For corner cases, a BeanRegistrationContributionProvider can be implemented. It allows to return a custom BeanFactoryContribution for a particualr bean definition. This usually uses the default implementation with a custom instance supplier. Note that this commit adds an temporary executable resolution that is meant to be replaced by the use of ConstructorResolver See gh-28088 --- .../generator/BeanFactoryContribution.java | 34 + .../generator/BeanFactoryInitialization.java | 110 +++ .../generator/BeanInstantiationGenerator.java | 47 ++ .../generator/BeanParameterGenerator.java | 5 + ...anRegistrationBeanFactoryContribution.java | 469 +++++++++++++ .../BeanRegistrationContributionProvider.java | 43 ++ .../DefaultBeanInstantiationGenerator.java | 16 +- ...tBeanRegistrationContributionProvider.java | 494 ++++++++++++++ ...anRegistrationBeanFactoryContribution.java | 42 ++ ...istrationBeanFactoryContributionTests.java | 646 ++++++++++++++++++ ...RegistrationContributionProviderTests.java | 67 ++ 11 files changed, 1965 insertions(+), 8 deletions(-) create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryInitialization.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanInstantiationGenerator.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationBeanFactoryContribution.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationContributionProvider.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/generator/DefaultBeanRegistrationContributionProvider.java create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/generator/InnerBeanRegistrationBeanFactoryContribution.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanRegistrationBeanFactoryContributionTests.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/generator/DefaultBeanRegistrationContributionProviderTests.java diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java new file mode 100644 index 00000000000..2eeeac9a988 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +/** + * Contribute optimizations ahead of time to initialize a bean factory. + * + * @author Stephane Nicoll + * @since 6.0 + */ +public interface BeanFactoryContribution { + + /** + * Contribute ahead of time optimizations to the specific + * {@link BeanFactoryInitialization}. + * @param initialization {@link BeanFactoryInitialization} to contribute to + */ + void applyTo(BeanFactoryInitialization initialization); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryInitialization.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryInitialization.java new file mode 100644 index 00000000000..afe7326c83d --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryInitialization.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import java.util.function.Consumer; +import java.util.function.Supplier; + +import javax.lang.model.element.Modifier; + +import org.springframework.aot.generator.GeneratedType; +import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.aot.generator.ProtectedAccess; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.javapoet.MethodSpec; + +/** + * The initialization of a {@link BeanFactory}. + * + * @author Andy Wilkinson + * @author Stephane Nicoll + * @since 6.0 + */ +public class BeanFactoryInitialization { + + private final GeneratedTypeContext generatedTypeContext; + + private final CodeBlock.Builder codeContributions; + + public BeanFactoryInitialization(GeneratedTypeContext generatedTypeContext) { + this.generatedTypeContext = generatedTypeContext; + this.codeContributions = CodeBlock.builder(); + } + + /** + * Return the {@link GeneratedTypeContext} to use to contribute + * additional methods or hints. + * @return the generation context + */ + public GeneratedTypeContext generatedTypeContext() { + return this.generatedTypeContext; + } + + /** + * Contribute code that initializes the bean factory and that does not + * require any privileged access. + * @param code the code to contribute + */ + public void contribute(Consumer code) { + CodeBlock.Builder builder = CodeBlock.builder(); + code.accept(builder); + CodeBlock codeBlock = builder.build(); + this.codeContributions.add(codeBlock); + if (!codeBlock.toString().endsWith("\n")) { + this.codeContributions.add("\n"); + } + } + + /** + * Contribute code that initializes the bean factory. If privileged access + * is required, a public method in the target package is created and + * invoked, rather than contributing the code directly. + * @param protectedAccess the {@link ProtectedAccess} instance to use + * @param methodName a method name to use if privileged access is required + * @param methodBody the contribution + */ + public void contribute(ProtectedAccess protectedAccess, Supplier methodName, + Consumer methodBody) { + String targetPackageName = this.generatedTypeContext.getMainGeneratedType().getClassName().packageName(); + String protectedPackageName = protectedAccess.getPrivilegedPackageName(targetPackageName); + if (protectedPackageName != null) { + GeneratedType type = this.generatedTypeContext.getGeneratedType(protectedPackageName); + MethodSpec.Builder method = MethodSpec.methodBuilder(methodName.get()) + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(DefaultListableBeanFactory.class, "beanFactory"); + CodeBlock.Builder code = CodeBlock.builder(); + methodBody.accept(code); + method.addCode(code.build()); + contribute(main -> main.addStatement("$T.$N(beanFactory)", type.getClassName(), type.addMethod(method))); + } + else { + contribute(methodBody); + } + } + + /** + * Return the code that has been contributed to this instance. + * @return the code + */ + public CodeBlock toCodeBlock() { + return this.codeContributions.build(); + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanInstantiationGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanInstantiationGenerator.java new file mode 100644 index 00000000000..5a776428c6d --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanInstantiationGenerator.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import java.lang.reflect.Executable; + +import org.springframework.aot.generator.CodeContribution; +import org.springframework.aot.hint.RuntimeHints; + +/** + * Generate code that instantiate a particular bean. + * + * @author Stephane Nicoll + * @since 6.0 + */ +public interface BeanInstantiationGenerator { + + /** + * Return the {@link Executable} that is used to create the bean instance + * for further metadata processing. + * @return the executable that is used to create the bean instance + */ + Executable getInstanceCreator(); + + /** + * Return the necessary code to instantiate a bean. + * @param runtimeHints the runtime hints instance to use + * @return a code contribution that provides an initialized bean instance + */ + CodeContribution generateBeanInstantiation(RuntimeHints runtimeHints); + +} + diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanParameterGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanParameterGenerator.java index e78c6589eb5..c372db9e6ea 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanParameterGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanParameterGenerator.java @@ -49,6 +49,11 @@ import org.springframework.util.ObjectUtils; */ public final class BeanParameterGenerator { + /** + * A default instance that does not handle inner bean definitions. + */ + public static final BeanParameterGenerator INSTANCE = new BeanParameterGenerator(); + private final ResolvableTypeGenerator typeGenerator = new ResolvableTypeGenerator(); private final Function innerBeanDefinitionGenerator; diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationBeanFactoryContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationBeanFactoryContribution.java new file mode 100644 index 00000000000..29f9057a420 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationBeanFactoryContribution.java @@ -0,0 +1,469 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import java.beans.BeanInfo; +import java.beans.IntrospectionException; +import java.beans.Introspector; +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.function.Predicate; + +import javax.lang.model.SourceVersion; + +import org.springframework.aot.generator.CodeContribution; +import org.springframework.aot.generator.ProtectedAccess; +import org.springframework.aot.generator.ResolvableTypeGenerator; +import org.springframework.aot.hint.ExecutableMode; +import org.springframework.aot.hint.ReflectionHints; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.beans.BeanInfoFactory; +import org.springframework.beans.ExtendedBeanInfoFactory; +import org.springframework.beans.MutablePropertyValues; +import org.springframework.beans.PropertyValue; +import org.springframework.beans.PropertyValues; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues; +import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; +import org.springframework.beans.factory.generator.config.BeanDefinitionRegistrar; +import org.springframework.beans.factory.support.AbstractBeanDefinition; +import org.springframework.core.AttributeAccessor; +import org.springframework.core.ResolvableType; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.javapoet.support.MultiStatement; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * A {@link BeanFactoryContribution} that registers a bean with the bean + * factory. + * + * @author Stephane Nicoll + * @since 6.0 + */ +public class BeanRegistrationBeanFactoryContribution implements BeanFactoryContribution { + + private static final BeanInfoFactory beanInfoFactory = new ExtendedBeanInfoFactory(); + + private static final ResolvableTypeGenerator typeGenerator = new ResolvableTypeGenerator(); + + private final String beanName; + + private final BeanDefinition beanDefinition; + + private final BeanInstantiationGenerator beanInstantiationGenerator; + + @Nullable + private final DefaultBeanRegistrationContributionProvider innerBeanRegistrationContributionProvider; + + private int nesting = 0; + + BeanRegistrationBeanFactoryContribution(String beanName, BeanDefinition beanDefinition, + BeanInstantiationGenerator beanInstantiationGenerator, + @Nullable DefaultBeanRegistrationContributionProvider innerBeanRegistrationContributionProvider) { + this.beanName = beanName; + this.beanDefinition = beanDefinition; + this.beanInstantiationGenerator = beanInstantiationGenerator; + this.innerBeanRegistrationContributionProvider = innerBeanRegistrationContributionProvider; + } + + public BeanRegistrationBeanFactoryContribution(String beanName, BeanDefinition beanDefinition, + BeanInstantiationGenerator beanInstantiationGenerator) { + this(beanName, beanDefinition, beanInstantiationGenerator, null); + } + + String getBeanName() { + return this.beanName; + } + + BeanDefinition getBeanDefinition() { + return this.beanDefinition; + } + + @Override + public void applyTo(BeanFactoryInitialization initialization) { + RuntimeHints runtimeHints = initialization.generatedTypeContext().runtimeHints(); + registerRuntimeHints(runtimeHints); + CodeContribution beanInstanceContribution = generateBeanInstance(runtimeHints); + // Write everything in one place + ProtectedAccess protectedAccess = beanInstanceContribution.protectedAccess(); + protectedAccess.analyze(this.beanDefinition.getResolvableType()); + initialization.contribute(protectedAccess, this::registerBeanMethodName, code -> + code.add(generateBeanRegistration(runtimeHints, beanInstanceContribution.statements()))); + } + + /** + * Register the necessary hints that are required to process the bean + * registration generated by this instance. + * @param runtimeHints the runtime hints to use + */ + void registerRuntimeHints(RuntimeHints runtimeHints) { + registerPropertyValuesRuntimeHints(runtimeHints); + } + + /** + * Generate the necessary code to register a {@link BeanDefinition} in the + * bean registry. + * @param runtimeHints the hints to use + * @param beanInstanceStatements the {@linkplain MultiStatement statements} + * to create and initialize the bean instance + * @return bean registration code + */ + CodeBlock generateBeanRegistration(RuntimeHints runtimeHints, MultiStatement beanInstanceStatements) { + BeanParameterGenerator parameterGenerator = createBeanParameterGenerator(runtimeHints); + Generator generator = new Generator(parameterGenerator); + return generator.generateBeanRegistration(beanInstanceStatements); + } + + /** + * Generate the necessary code to create a {@link BeanDefinition}. + * @param runtimeHints the hints to use + * @return bean definition code + */ + CodeBlock generateBeanDefinition(RuntimeHints runtimeHints) { + CodeContribution beanInstanceContribution = generateBeanInstance(runtimeHints); + BeanParameterGenerator parameterGenerator = createBeanParameterGenerator(runtimeHints); + Generator generator = new Generator(parameterGenerator); + return generator.generateBeanDefinition(beanInstanceContribution.statements()); + } + + private BeanParameterGenerator createBeanParameterGenerator(RuntimeHints runtimeHints) { + return new BeanParameterGenerator(beanDefinition -> + generateInnerBeanDefinition(beanDefinition, runtimeHints)); + } + + /** + * Return the predicate to use to include Bean Definition + * {@link AttributeAccessor attributes}. + * @return the bean definition's attributes include filter + */ + protected Predicate getAttributeFilter() { + return candidate -> false; + } + + /** + * Specify if the creator {@link Executable} should be defined. By default, + * a creator is specified if the {@code instanceSupplier} callback is used + * with an {@code instanceContext} callback. + * @param instanceCreator the executable to use to instantiate the bean + * @return {@code true} to declare the creator + */ + protected boolean shouldDeclareCreator(Executable instanceCreator) { + if (instanceCreator instanceof Method) { + return true; + } + if (instanceCreator instanceof Constructor constructor) { + int minArgs = isInnerClass(constructor.getDeclaringClass()) ? 2 : 1; + return instanceCreator.getParameterCount() >= minArgs; + } + return false; + } + + /** + * Return the necessary code to instantiate and post-process a bean. + * @param runtimeHints the {@link RuntimeHints} to use + * @return a code contribution that provides an initialized bean instance + */ + protected CodeContribution generateBeanInstance(RuntimeHints runtimeHints) { + return this.beanInstantiationGenerator.generateBeanInstantiation(runtimeHints); + } + + private void registerPropertyValuesRuntimeHints(RuntimeHints runtimeHints) { + if (!this.beanDefinition.hasPropertyValues()) { + return; + } + BeanInfo beanInfo = getBeanInfo(this.beanDefinition.getResolvableType().toClass()); + if (beanInfo != null) { + ReflectionHints reflectionHints = runtimeHints.reflection(); + this.beanDefinition.getPropertyValues().getPropertyValueList().forEach(propertyValue -> { + Method writeMethod = findWriteMethod(beanInfo, propertyValue.getName()); + if (writeMethod != null) { + reflectionHints.registerMethod(writeMethod, hint -> hint.withMode(ExecutableMode.INVOKE)); + } + }); + } + } + + @Nullable + private BeanInfo getBeanInfo(Class beanType) { + try { + BeanInfo beanInfo = beanInfoFactory.getBeanInfo(beanType); + if (beanInfo != null) { + return beanInfo; + } + return Introspector.getBeanInfo(beanType, Introspector.IGNORE_ALL_BEANINFO); + } + catch (IntrospectionException ex) { + return null; + } + } + + @Nullable + private Method findWriteMethod(BeanInfo beanInfo, String propertyName) { + return Arrays.stream(beanInfo.getPropertyDescriptors()) + .filter(pd -> propertyName.equals(pd.getName())) + .map(java.beans.PropertyDescriptor::getWriteMethod) + .filter(Objects::nonNull).findFirst().orElse(null); + } + + protected CodeBlock initializeBeanDefinitionRegistrar() { + return CodeBlock.of("$T.of($S, ", BeanDefinitionRegistrar.class, this.beanName); + } + + private Class getUserBeanClass() { + return ClassUtils.getUserClass(this.beanDefinition.getResolvableType().toClass()); + } + + private void handleCreatorReference(Builder code, Executable creator) { + if (creator instanceof Method) { + code.add(".withFactoryMethod($T.class, $S", creator.getDeclaringClass(), creator.getName()); + if (creator.getParameterCount() > 0) { + code.add(", "); + } + } + else { + code.add(".withConstructor("); + } + code.add(BeanParameterGenerator.INSTANCE.generateExecutableParameterTypes(creator)); + code.add(")"); + } + + private CodeBlock generateInnerBeanDefinition(BeanDefinition beanDefinition, RuntimeHints runtimeHints) { + if (this.innerBeanRegistrationContributionProvider == null) { + throw new IllegalStateException("This generator does not handle inner bean definition " + beanDefinition); + } + BeanRegistrationBeanFactoryContribution innerBeanRegistrationContribution = this.innerBeanRegistrationContributionProvider + .getInnerBeanRegistrationContribution(this, beanDefinition); + innerBeanRegistrationContribution.nesting = this.nesting + 1; + innerBeanRegistrationContribution.registerRuntimeHints(runtimeHints); + return innerBeanRegistrationContribution.generateBeanDefinition(runtimeHints); + } + + private String registerBeanMethodName() { + Executable instanceCreator = this.beanInstantiationGenerator.getInstanceCreator(); + if (instanceCreator instanceof Method method) { + String target = (isValidName(this.beanName)) ? this.beanName : method.getName(); + return String.format("register%s_%s", method.getDeclaringClass().getSimpleName(), target); + } + else if (instanceCreator.getDeclaringClass().getEnclosingClass() != null) { + String target = (isValidName(this.beanName)) ? this.beanName : getUserBeanClass().getSimpleName(); + Class enclosingClass = instanceCreator.getDeclaringClass().getEnclosingClass(); + return String.format("register%s_%s", enclosingClass.getSimpleName(), target); + } + else { + String target = (isValidName(this.beanName)) ? this.beanName : getUserBeanClass().getSimpleName(); + return "register" + StringUtils.capitalize(target); + } + } + + private boolean isValidName(@Nullable String name) { + return name != null && SourceVersion.isIdentifier(name) && !SourceVersion.isKeyword(name); + } + + private String determineVariableName(String name) { + return name + "_".repeat(this.nesting); + } + + private static boolean isInnerClass(Class type) { + return type.isMemberClass() && !java.lang.reflect.Modifier.isStatic(type.getModifiers()); + } + + class Generator { + + private final BeanParameterGenerator parameterGenerator; + + private final BeanDefinition beanDefinition; + + Generator(BeanParameterGenerator parameterGenerator) { + this.parameterGenerator = parameterGenerator; + this.beanDefinition = BeanRegistrationBeanFactoryContribution.this.beanDefinition; + } + + CodeBlock generateBeanRegistration(MultiStatement instanceStatements) { + CodeBlock.Builder code = CodeBlock.builder(); + initializeBeanDefinitionRegistrar(instanceStatements, code); + code.addStatement(".register(beanFactory)"); + return code.build(); + } + + CodeBlock generateBeanDefinition(MultiStatement instanceStatements) { + CodeBlock.Builder code = CodeBlock.builder(); + initializeBeanDefinitionRegistrar(instanceStatements, code); + code.add(".toBeanDefinition()"); + return code.build(); + } + + private void initializeBeanDefinitionRegistrar(MultiStatement instanceStatements, Builder code) { + Executable instanceCreator = BeanRegistrationBeanFactoryContribution.this.beanInstantiationGenerator.getInstanceCreator(); + code.add(BeanRegistrationBeanFactoryContribution.this.initializeBeanDefinitionRegistrar()); + generateBeanType(code); + code.add(")"); + boolean shouldDeclareCreator = shouldDeclareCreator(instanceCreator); + if (shouldDeclareCreator) { + handleCreatorReference(code, instanceCreator); + } + code.add("\n").indent().indent(); + code.add(".instanceSupplier("); + code.add(instanceStatements.toCodeBlock()); + code.add(")").unindent().unindent(); + handleBeanDefinitionMetadata(code); + } + + private void generateBeanType(Builder code) { + ResolvableType resolvableType = this.beanDefinition.getResolvableType(); + if (resolvableType.hasGenerics() && !hasUnresolvedGenerics(resolvableType)) { + code.add(typeGenerator.generateTypeFor(resolvableType)); + } + else { + code.add("$T.class", getUserBeanClass()); + } + } + + private boolean hasUnresolvedGenerics(ResolvableType resolvableType) { + if (resolvableType.hasUnresolvableGenerics()) { + return true; + } + for (ResolvableType generic : resolvableType.getGenerics()) { + if (hasUnresolvedGenerics(generic)) { + return true; + } + } + return false; + } + + private void handleBeanDefinitionMetadata(Builder code) { + String bdVariable = determineVariableName("bd"); + MultiStatement statements = new MultiStatement(); + if (this.beanDefinition.isPrimary()) { + statements.addStatement("$L.setPrimary(true)", bdVariable); + } + String scope = this.beanDefinition.getScope(); + if (StringUtils.hasText(scope) && !ConfigurableBeanFactory.SCOPE_SINGLETON.equals(scope)) { + statements.addStatement("$L.setScope($S)", bdVariable, scope); + } + String[] dependsOn = this.beanDefinition.getDependsOn(); + if (!ObjectUtils.isEmpty(dependsOn)) { + statements.addStatement("$L.setDependsOn($L)", bdVariable, + this.parameterGenerator.generateParameterValue(dependsOn)); + } + if (this.beanDefinition.isLazyInit()) { + statements.addStatement("$L.setLazyInit(true)", bdVariable); + } + if (!this.beanDefinition.isAutowireCandidate()) { + statements.addStatement("$L.setAutowireCandidate(false)", bdVariable); + } + if (this.beanDefinition instanceof AbstractBeanDefinition + && ((AbstractBeanDefinition) this.beanDefinition).isSynthetic()) { + statements.addStatement("$L.setSynthetic(true)", bdVariable); + } + if (this.beanDefinition.getRole() != BeanDefinition.ROLE_APPLICATION) { + statements.addStatement("$L.setRole($L)", bdVariable, this.beanDefinition.getRole()); + } + Map indexedArgumentValues = this.beanDefinition.getConstructorArgumentValues() + .getIndexedArgumentValues(); + if (!indexedArgumentValues.isEmpty()) { + handleArgumentValues(statements, bdVariable, indexedArgumentValues); + } + if (this.beanDefinition.hasPropertyValues()) { + handlePropertyValues(statements, bdVariable, this.beanDefinition.getPropertyValues()); + } + if (this.beanDefinition.attributeNames().length > 0) { + handleAttributes(statements, bdVariable); + } + if (statements.isEmpty()) { + return; + } + code.add(statements.toCodeBlock(".customize((" + bdVariable + ") ->")); + code.add(")"); + } + + private void handleArgumentValues(MultiStatement statements, String bdVariable, + Map indexedArgumentValues) { + if (indexedArgumentValues.size() == 1) { + Entry entry = indexedArgumentValues.entrySet().iterator().next(); + statements.addStatement(generateArgumentValue(bdVariable + ".getConstructorArgumentValues().", + entry.getKey(), entry.getValue())); + } + else { + String avVariable = determineVariableName("argumentValues"); + statements.addStatement("$T $L = $L.getConstructorArgumentValues()", ConstructorArgumentValues.class, avVariable, bdVariable); + statements.addAll(indexedArgumentValues.entrySet(), entry -> generateArgumentValue(avVariable + ".", + entry.getKey(), entry.getValue())); + } + } + + private CodeBlock generateArgumentValue(String prefix, Integer index, ValueHolder valueHolder) { + Builder code = CodeBlock.builder(); + code.add(prefix); + code.add("addIndexedArgumentValue($L, ", index); + Object value = valueHolder.getValue(); + code.add(this.parameterGenerator.generateParameterValue(value)); + code.add(")"); + return code.build(); + } + + private void handlePropertyValues(MultiStatement statements, String bdVariable, + PropertyValues propertyValues) { + PropertyValue[] properties = propertyValues.getPropertyValues(); + if (properties.length == 1) { + statements.addStatement(generatePropertyValue(bdVariable + ".getPropertyValues().", properties[0])); + } + else { + String pvVariable = determineVariableName("propertyValues"); + statements.addStatement("$T $L = $L.getPropertyValues()", MutablePropertyValues.class, pvVariable, bdVariable); + for (PropertyValue property : properties) { + statements.addStatement(generatePropertyValue(pvVariable + ".", property)); + } + } + } + + private CodeBlock generatePropertyValue(String prefix, PropertyValue property) { + Builder code = CodeBlock.builder(); + code.add(prefix); + code.add("addPropertyValue($S, ", property.getName()); + Object value = property.getValue(); + code.add(this.parameterGenerator.generateParameterValue(value)); + code.add(")"); + return code.build(); + } + + private void handleAttributes(MultiStatement statements, String bdVariable) { + String[] attributeNames = this.beanDefinition.attributeNames(); + Predicate filter = getAttributeFilter(); + for (String attributeName : attributeNames) { + if (filter.test(attributeName)) { + Object value = this.beanDefinition.getAttribute(attributeName); + Builder code = CodeBlock.builder(); + code.add("$L.setAttribute($S, ", bdVariable, attributeName); + code.add((this.parameterGenerator.generateParameterValue(value))); + code.add(")"); + statements.addStatement(code.build()); + } + } + } + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationContributionProvider.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationContributionProvider.java new file mode 100644 index 00000000000..77098fb3ece --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationContributionProvider.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.lang.Nullable; + +/** + * Strategy interface to be implemented by components that require custom + * contribution for a bean definition. + * + * @author Stephane Nicoll + * @since 6.0 + */ +@FunctionalInterface +public interface BeanRegistrationContributionProvider { + + /** + * Return the {@link BeanFactoryContribution} that is capable of contributing + * the registration of a bean for the given {@link RootBeanDefinition} or + * {@code null} if the specified bean definition is not supported. + * @param beanName the bean name to handle + * @param beanDefinition the merged bean definition + * @return a contribution for the specified bean definition or {@code null} + */ + @Nullable + BeanFactoryContribution getContributionFor(String beanName, RootBeanDefinition beanDefinition); + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/DefaultBeanInstantiationGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/DefaultBeanInstantiationGenerator.java index e8e57d13965..de781d9a5b6 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/generator/DefaultBeanInstantiationGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/DefaultBeanInstantiationGenerator.java @@ -33,12 +33,12 @@ import org.springframework.javapoet.CodeBlock; import org.springframework.util.ClassUtils; /** - * Generate the necessary statements to instantiate a bean. + * Default {@link BeanInstantiationGenerator} implementation. * * @author Stephane Nicoll * @see BeanInstantiationContribution */ -class DefaultBeanInstantiationGenerator { +class DefaultBeanInstantiationGenerator implements BeanInstantiationGenerator { private final Executable instanceCreator; @@ -57,12 +57,12 @@ class DefaultBeanInstantiationGenerator { .assignReturnType(member -> !this.contributions.isEmpty()).build(); } - /** - * Return the necessary code to instantiate and post-process the bean - * handled by this instance. - * @param runtimeHints the runtime hints instance to use - * @return a code contribution that provides an initialized bean instance - */ + @Override + public Executable getInstanceCreator() { + return this.instanceCreator; + } + + @Override public CodeContribution generateBeanInstantiation(RuntimeHints runtimeHints) { DefaultCodeContribution codeContribution = new DefaultCodeContribution(runtimeHints); codeContribution.protectedAccess().analyze(this.instanceCreator, this.beanInstanceOptions); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/DefaultBeanRegistrationContributionProvider.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/DefaultBeanRegistrationContributionProvider.java new file mode 100644 index 00000000000..a993e0dc7c7 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/DefaultBeanRegistrationContributionProvider.java @@ -0,0 +1,494 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.BeanReference; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues; +import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; +import org.springframework.beans.factory.support.AbstractBeanDefinition; +import org.springframework.beans.factory.support.BeanDefinitionValueResolver; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.core.OrderComparator; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.MergedAnnotations; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.function.SingletonSupplier; + +/** + * Default {@link BeanRegistrationContributionProvider} implementation. + * + * @author Stephane Nicoll + * @since 6.0 + */ +public final class DefaultBeanRegistrationContributionProvider implements BeanRegistrationContributionProvider { + + private final DefaultListableBeanFactory beanFactory; + + private final ExecutableProvider executableProvider; + + private final Supplier> beanPostProcessors; + + public DefaultBeanRegistrationContributionProvider(DefaultListableBeanFactory beanFactory) { + this.beanFactory = beanFactory; + this.executableProvider = new ExecutableProvider(beanFactory); + this.beanPostProcessors = new SingletonSupplier<>(null, + () -> loadAotContributingBeanPostProcessors(beanFactory)); + } + + private static List loadAotContributingBeanPostProcessors( + DefaultListableBeanFactory beanFactory) { + String[] postProcessorNames = beanFactory.getBeanNamesForType(AotContributingBeanPostProcessor.class, true, false); + List postProcessors = new ArrayList<>(); + for (String ppName : postProcessorNames) { + postProcessors.add(beanFactory.getBean(ppName, AotContributingBeanPostProcessor.class)); + } + sortPostProcessors(postProcessors, beanFactory); + return postProcessors; + } + + @Override + public BeanRegistrationBeanFactoryContribution getContributionFor( + String beanName, RootBeanDefinition beanDefinition) { + BeanInstantiationGenerator beanInstantiationGenerator = getBeanInstantiationGenerator( + beanName, beanDefinition); + return new BeanRegistrationBeanFactoryContribution(beanName, beanDefinition, beanInstantiationGenerator, this); + } + + public BeanInstantiationGenerator getBeanInstantiationGenerator( + String beanName, RootBeanDefinition beanDefinition) { + return new DefaultBeanInstantiationGenerator(determineExecutable(beanDefinition), + determineBeanInstanceContributions(beanName, beanDefinition)); + } + + /** + * Return a {@link BeanRegistrationBeanFactoryContribution} that is capable of + * contributing the specified inner {@link BeanDefinition}. + * @param parent the contribution of the parent bean definition + * @param innerBeanDefinition the inner bean definition + * @return a contribution for the specified inner bean definition + */ + BeanRegistrationBeanFactoryContribution getInnerBeanRegistrationContribution( + BeanRegistrationBeanFactoryContribution parent, BeanDefinition innerBeanDefinition) { + BeanDefinitionValueResolver bdvr = new BeanDefinitionValueResolver(this.beanFactory, + parent.getBeanName(), parent.getBeanDefinition()); + return bdvr.resolveInnerBean(null, innerBeanDefinition, (beanName, bd) -> + new InnerBeanRegistrationBeanFactoryContribution(beanName, bd, + getBeanInstantiationGenerator(beanName, bd), this)); + } + + private Executable determineExecutable(RootBeanDefinition beanDefinition) { + Executable executable = this.executableProvider.detectBeanInstanceExecutable(beanDefinition); + if (executable == null) { + throw new IllegalStateException("No suitable executor found for " + beanDefinition); + } + return executable; + } + + private List determineBeanInstanceContributions( + String beanName, RootBeanDefinition beanDefinition) { + List contributions = new ArrayList<>(); + for (AotContributingBeanPostProcessor pp : this.beanPostProcessors.get()) { + BeanInstantiationContribution contribution = pp.contribute(beanDefinition, + beanDefinition.getResolvableType().toClass(), beanName); + if (contribution != null) { + contributions.add(contribution); + } + } + return contributions; + } + + private static void sortPostProcessors(List postProcessors, ConfigurableListableBeanFactory beanFactory) { + // Nothing to sort? + if (postProcessors.size() <= 1) { + return; + } + Comparator comparatorToUse = null; + if (beanFactory instanceof DefaultListableBeanFactory) { + comparatorToUse = ((DefaultListableBeanFactory) beanFactory).getDependencyComparator(); + } + if (comparatorToUse == null) { + comparatorToUse = OrderComparator.INSTANCE; + } + postProcessors.sort(comparatorToUse); + } + + // FIXME: copy-paste from Spring Native that should go away in favor of ConstructorResolver + private static class ExecutableProvider { + + private static final Log logger = LogFactory.getLog(ExecutableProvider.class); + + private final ConfigurableBeanFactory beanFactory; + + private final ClassLoader classLoader; + + ExecutableProvider(ConfigurableBeanFactory beanFactory) { + this.beanFactory = beanFactory; + this.classLoader = (beanFactory.getBeanClassLoader() != null + ? beanFactory.getBeanClassLoader() : getClass().getClassLoader()); + } + + @Nullable + Executable detectBeanInstanceExecutable(BeanDefinition beanDefinition) { + Supplier beanType = () -> getBeanType(beanDefinition); + List valueTypes = beanDefinition.hasConstructorArgumentValues() + ? determineParameterValueTypes(beanDefinition.getConstructorArgumentValues()) : Collections.emptyList(); + Method resolvedFactoryMethod = resolveFactoryMethod(beanDefinition, valueTypes); + if (resolvedFactoryMethod != null) { + return resolvedFactoryMethod; + } + Class factoryBeanClass = getFactoryBeanClass(beanDefinition); + if (factoryBeanClass != null && !factoryBeanClass.equals(beanDefinition.getResolvableType().toClass())) { + ResolvableType resolvableType = beanDefinition.getResolvableType(); + boolean isCompatible = ResolvableType.forClass(factoryBeanClass).as(FactoryBean.class) + .getGeneric(0).isAssignableFrom(resolvableType); + if (isCompatible) { + return resolveConstructor(() -> ResolvableType.forClass(factoryBeanClass), valueTypes); + } + else { + throw new IllegalStateException(String.format("Incompatible target type '%s' for factory bean '%s'", + resolvableType.toClass().getName(), factoryBeanClass.getName())); + } + } + Executable resolvedConstructor = resolveConstructor(beanType, valueTypes); + if (resolvedConstructor != null) { + return resolvedConstructor; + } + Executable resolvedConstructorOrFactoryMethod = getField(beanDefinition, + "resolvedConstructorOrFactoryMethod", Executable.class); + if (resolvedConstructorOrFactoryMethod != null) { + logger.error("resolvedConstructorOrFactoryMethod required for " + beanDefinition); + return resolvedConstructorOrFactoryMethod; + } + return null; + } + + private List determineParameterValueTypes(ConstructorArgumentValues constructorArgumentValues) { + List parameterTypes = new ArrayList<>(); + for (ValueHolder valueHolder : constructorArgumentValues.getIndexedArgumentValues().values()) { + if (valueHolder.getType() != null) { + parameterTypes.add(ResolvableType.forClass(loadClass(valueHolder.getType()))); + } + else { + Object value = valueHolder.getValue(); + if (value instanceof BeanReference) { + parameterTypes.add(ResolvableType.forClass( + this.beanFactory.getType(((BeanReference) value).getBeanName(), false))); + } + else if (value instanceof BeanDefinition) { + parameterTypes.add(extractTypeFromBeanDefinition(getBeanType((BeanDefinition) value))); + } + else { + parameterTypes.add(ResolvableType.forInstance(value)); + } + } + } + return parameterTypes; + } + + private ResolvableType extractTypeFromBeanDefinition(ResolvableType type) { + if (FactoryBean.class.isAssignableFrom(type.toClass())) { + return type.as(FactoryBean.class).getGeneric(0); + } + return type; + } + + @Nullable + private Method resolveFactoryMethod(BeanDefinition beanDefinition, List valueTypes) { + if (beanDefinition instanceof RootBeanDefinition rbd) { + Method resolvedFactoryMethod = rbd.getResolvedFactoryMethod(); + if (resolvedFactoryMethod != null) { + return resolvedFactoryMethod; + } + } + String factoryMethodName = beanDefinition.getFactoryMethodName(); + if (factoryMethodName != null) { + List methods = new ArrayList<>(); + Class beanClass = getBeanClass(beanDefinition); + if (beanClass == null) { + throw new IllegalStateException("Failed to determine bean class of " + beanDefinition); + } + ReflectionUtils.doWithMethods(beanClass, methods::add, + method -> isFactoryMethodCandidate(beanClass, method, factoryMethodName)); + if (methods.size() >= 1) { + Function> parameterTypesFactory = method -> { + List types = new ArrayList<>(); + for (int i = 0; i < method.getParameterCount(); i++) { + types.add(ResolvableType.forMethodParameter(method, i)); + } + return types; + }; + return (Method) resolveFactoryMethod(methods, parameterTypesFactory, valueTypes); + } + } + return null; + } + + private boolean isFactoryMethodCandidate(Class beanClass, Method method, String factoryMethodName) { + if (method.getName().equals(factoryMethodName)) { + if (Modifier.isStatic(method.getModifiers())) { + return method.getDeclaringClass().equals(beanClass); + } + return !Modifier.isPrivate(method.getModifiers()); + } + return false; + } + + @Nullable + private Executable resolveConstructor(Supplier beanType, List valueTypes) { + Class type = ClassUtils.getUserClass(beanType.get().toClass()); + Constructor[] constructors = type.getDeclaredConstructors(); + if (constructors.length == 1) { + return constructors[0]; + } + for (Constructor constructor : constructors) { + if (MergedAnnotations.from(constructor).isPresent(Autowired.class)) { + return constructor; + } + } + Function, List> parameterTypesFactory = executable -> { + List types = new ArrayList<>(); + for (int i = 0; i < executable.getParameterCount(); i++) { + types.add(ResolvableType.forConstructorParameter(executable, i)); + } + return types; + }; + List matches = Arrays.stream(constructors) + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, FallbackMode.NONE)).toList(); + if (matches.size() == 1) { + return matches.get(0); + } + List assignableElementFallbackMatches = Arrays.stream(constructors) + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, FallbackMode.ASSIGNABLE_ELEMENT)).toList(); + if (assignableElementFallbackMatches.size() == 1) { + return assignableElementFallbackMatches.get(0); + } + List typeConversionFallbackMatches = Arrays.stream(constructors) + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, ExecutableProvider.FallbackMode.TYPE_CONVERSION)).toList(); + return (typeConversionFallbackMatches.size() == 1) ? typeConversionFallbackMatches.get(0) : null; + } + + private Executable resolveFactoryMethod(List executables, + Function> parameterTypesFactory, List valueTypes) { + List matches = executables.stream() + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, ExecutableProvider.FallbackMode.NONE)).toList(); + if (matches.size() == 1) { + return matches.get(0); + } + List assignableElementFallbackMatches = executables.stream() + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, ExecutableProvider.FallbackMode.ASSIGNABLE_ELEMENT)).toList(); + if (assignableElementFallbackMatches.size() == 1) { + return assignableElementFallbackMatches.get(0); + } + List typeConversionFallbackMatches = executables.stream() + .filter(executable -> match(parameterTypesFactory.apply(executable), + valueTypes, ExecutableProvider.FallbackMode.TYPE_CONVERSION)).toList(); + if (typeConversionFallbackMatches.size() > 1) { + throw new IllegalStateException("Multiple matches with parameters '" + + valueTypes + "': " + typeConversionFallbackMatches); + } + return (typeConversionFallbackMatches.size() == 1) ? typeConversionFallbackMatches.get(0) : null; + } + + private boolean match(List parameterTypes, List valueTypes, + ExecutableProvider.FallbackMode fallbackMode) { + if (parameterTypes.size() != valueTypes.size()) { + return false; + } + for (int i = 0; i < parameterTypes.size(); i++) { + if (!isMatch(parameterTypes.get(i), valueTypes.get(i), fallbackMode)) { + return false; + } + } + return true; + } + + private boolean isMatch(ResolvableType parameterType, ResolvableType valueType, + ExecutableProvider.FallbackMode fallbackMode) { + if (isAssignable(valueType).test(parameterType)) { + return true; + } + return switch (fallbackMode) { + case ASSIGNABLE_ELEMENT -> isAssignable(valueType).test(extractElementType(parameterType)); + case TYPE_CONVERSION -> typeConversionFallback(valueType).test(parameterType); + default -> false; + }; + } + + private Predicate isAssignable(ResolvableType valueType) { + return parameterType -> { + if (valueType.hasUnresolvableGenerics()) { + return parameterType.toClass().isAssignableFrom(valueType.toClass()); + } + else { + return parameterType.isAssignableFrom(valueType); + } + }; + } + + private ResolvableType extractElementType(ResolvableType parameterType) { + if (parameterType.isArray()) { + return parameterType.getComponentType(); + } + if (Collection.class.isAssignableFrom(parameterType.toClass())) { + return parameterType.as(Collection.class).getGeneric(0); + } + return ResolvableType.NONE; + } + + private Predicate typeConversionFallback(ResolvableType valueType) { + return parameterType -> { + if (valueOrCollection(valueType, this::isStringForClassFallback).test(parameterType)) { + return true; + } + return valueOrCollection(valueType, this::isSimpleConvertibleType).test(parameterType); + }; + } + + private Predicate valueOrCollection(ResolvableType valueType, + Function> predicateProvider) { + return parameterType -> { + if (predicateProvider.apply(valueType).test(parameterType)) { + return true; + } + if (predicateProvider.apply(extractElementType(valueType)).test(extractElementType(parameterType))) { + return true; + } + return (predicateProvider.apply(valueType).test(extractElementType(parameterType))); + }; + } + + /** + * Return a {@link Predicate} for a parameter type that checks if its target value + * is a {@link Class} and the value type is a {@link String}. This is a regular use + * cases where a {@link Class} is defined in the bean definition as an FQN. + * @param valueType the type of the value + * @return a predicate to indicate a fallback match for a String to Class parameter + */ + private Predicate isStringForClassFallback(ResolvableType valueType) { + return parameterType -> (valueType.isAssignableFrom(String.class) + && parameterType.isAssignableFrom(Class.class)); + } + + private Predicate isSimpleConvertibleType(ResolvableType valueType) { + return parameterType -> isSimpleConvertibleType(parameterType.toClass()) + && isSimpleConvertibleType(valueType.toClass()); + } + + @Nullable + private Class getFactoryBeanClass(BeanDefinition beanDefinition) { + if (beanDefinition instanceof RootBeanDefinition rbd) { + if (rbd.hasBeanClass()) { + Class beanClass = rbd.getBeanClass(); + return FactoryBean.class.isAssignableFrom(beanClass) ? beanClass : null; + } + } + return null; + } + + @Nullable + private Class getBeanClass(BeanDefinition beanDefinition) { + if (beanDefinition instanceof AbstractBeanDefinition abd) { + return abd.hasBeanClass() ? abd.getBeanClass() : loadClass(abd.getBeanClassName()); + } + return (beanDefinition.getBeanClassName() != null) ? loadClass(beanDefinition.getBeanClassName()) : null; + } + + private ResolvableType getBeanType(BeanDefinition beanDefinition) { + ResolvableType resolvableType = beanDefinition.getResolvableType(); + if (resolvableType != ResolvableType.NONE) { + return resolvableType; + } + if (beanDefinition instanceof RootBeanDefinition rbd) { + if (rbd.hasBeanClass()) { + return ResolvableType.forClass(rbd.getBeanClass()); + } + } + String beanClassName = beanDefinition.getBeanClassName(); + if (beanClassName != null) { + return ResolvableType.forClass(loadClass(beanClassName)); + } + throw new IllegalStateException("Failed to determine bean class of " + beanDefinition); + } + + private Class loadClass(String beanClassName) { + try { + return ClassUtils.forName(beanClassName, this.classLoader); + } + catch (ClassNotFoundException ex) { + throw new IllegalStateException("Failed to load class " + beanClassName); + } + } + + @Nullable + private T getField(BeanDefinition beanDefinition, String fieldName, Class targetType) { + Field field = ReflectionUtils.findField(RootBeanDefinition.class, fieldName); + ReflectionUtils.makeAccessible(field); + return targetType.cast(ReflectionUtils.getField(field, beanDefinition)); + } + + public static boolean isSimpleConvertibleType(Class type) { + return (type.isPrimitive() && type != void.class) || + type == Double.class || type == Float.class || type == Long.class || + type == Integer.class || type == Short.class || type == Character.class || + type == Byte.class || type == Boolean.class || type == String.class; + } + + + enum FallbackMode { + + NONE, + + ASSIGNABLE_ELEMENT, + + TYPE_CONVERSION + + } + + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/InnerBeanRegistrationBeanFactoryContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/InnerBeanRegistrationBeanFactoryContribution.java new file mode 100644 index 00000000000..436135cfd7b --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/InnerBeanRegistrationBeanFactoryContribution.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.generator.config.BeanDefinitionRegistrar; +import org.springframework.javapoet.CodeBlock; + +/** + * A specialization of {@link BeanRegistrationContributionProvider} that handles + * inner bean definitions. + * + * @author Stephane Nicoll + */ +class InnerBeanRegistrationBeanFactoryContribution extends BeanRegistrationBeanFactoryContribution { + + InnerBeanRegistrationBeanFactoryContribution(String beanName, BeanDefinition beanDefinition, + BeanInstantiationGenerator beanInstantiationGenerator, + DefaultBeanRegistrationContributionProvider innerBeanRegistrationContributionProvider) { + super(beanName, beanDefinition, beanInstantiationGenerator, innerBeanRegistrationContributionProvider); + } + + @Override + protected CodeBlock initializeBeanDefinitionRegistrar() { + return CodeBlock.of("$T.inner(", BeanDefinitionRegistrar.class); + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanRegistrationBeanFactoryContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanRegistrationBeanFactoryContributionTests.java new file mode 100644 index 00000000000..1dbb4685fcb --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanRegistrationBeanFactoryContributionTests.java @@ -0,0 +1,646 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import java.io.IOException; +import java.io.StringWriter; +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generator.DefaultGeneratedTypeContext; +import org.springframework.aot.generator.GeneratedType; +import org.springframework.aot.hint.ExecutableHint; +import org.springframework.aot.hint.ExecutableMode; +import org.springframework.aot.hint.ReflectionHints; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.TypeReference; +import org.springframework.beans.MutablePropertyValues; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.support.AbstractBeanDefinition; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration.EnvironmentAwareComponent; +import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration.NoDependencyComponent; +import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration; +import org.springframework.beans.testfixture.beans.factory.generator.factory.SampleFactory; +import org.springframework.beans.testfixture.beans.factory.generator.injection.InjectionComponent; +import org.springframework.beans.testfixture.beans.factory.generator.property.ConfigurableBean; +import org.springframework.beans.testfixture.beans.factory.generator.visibility.ProtectedConstructorComponent; +import org.springframework.beans.testfixture.beans.factory.generator.visibility.ProtectedFactoryMethod; +import org.springframework.core.env.Environment; +import org.springframework.core.testfixture.aot.generator.visibility.PublicFactoryBean; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.javapoet.support.CodeSnippet; +import org.springframework.javapoet.support.MultiStatement; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link BeanRegistrationBeanFactoryContribution}. + * + * @author Stephane Nicoll + */ +class BeanRegistrationBeanFactoryContributionTests { + + private final DefaultGeneratedTypeContext generatedTypeContext = new DefaultGeneratedTypeContext("com.example", packageName -> GeneratedType.of(ClassName.get(packageName, "Test"))); + + private final BeanFactoryInitialization initialization = new BeanFactoryInitialization(this.generatedTypeContext); + + @Test + void generateUsingConstructor() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(InjectionComponent.class).getBeanDefinition(); + CodeSnippet registration = beanRegistration(beanDefinition, singleConstructor(InjectionComponent.class), code -> code.add("() -> test")); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", InjectionComponent.class).withConstructor(String.class) + .instanceSupplier(() -> test).register(beanFactory); + """); + } + + @Test + void generateUsingConstructorWithNoArgument() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(SimpleConfiguration.class).getBeanDefinition(); + CodeSnippet registration = beanRegistration(beanDefinition, singleConstructor(SimpleConfiguration.class), code -> code.add("() -> test")); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> test).register(beanFactory); + """); + } + + @Test + void generateUsingConstructorOnInnerClass() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(EnvironmentAwareComponent.class).getBeanDefinition(); + CodeSnippet registration = beanRegistration(beanDefinition, singleConstructor(EnvironmentAwareComponent.class), code -> code.add("() -> test")); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", InnerComponentConfiguration.EnvironmentAwareComponent.class).withConstructor(InnerComponentConfiguration.class, Environment.class) + .instanceSupplier(() -> test).register(beanFactory); + """); + } + + @Test + void generateUsingConstructorOnInnerClassWithNoExtraArg() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(NoDependencyComponent.class).getBeanDefinition(); + CodeSnippet registration = beanRegistration(beanDefinition, singleConstructor(NoDependencyComponent.class), code -> code.add("() -> test")); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", InnerComponentConfiguration.NoDependencyComponent.class) + .instanceSupplier(() -> test).register(beanFactory); + """); + } + + @Test + void generateUsingFactoryMethod() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(String.class).getBeanDefinition(); + CodeSnippet registration = beanRegistration(beanDefinition, method(SampleFactory.class, "create", String.class), code -> code.add("() -> test")); + assertThat(registration.hasImport(SampleFactory.class)).isTrue(); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", String.class).withFactoryMethod(SampleFactory.class, "create", String.class) + .instanceSupplier(() -> test).register(beanFactory); + """); + } + + @Test + void generateUsingFactoryMethodWithNoArgument() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(Integer.class).getBeanDefinition(); + CodeSnippet registration = beanRegistration(beanDefinition, method(SampleFactory.class, "integerBean"), code -> code.add("() -> test")); + assertThat(registration.hasImport(SampleFactory.class)).isTrue(); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", Integer.class).withFactoryMethod(SampleFactory.class, "integerBean") + .instanceSupplier(() -> test).register(beanFactory); + """); + } + + @Test + void generateUsingPublicAccessDoesNotAccessAnotherPackage() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(SimpleConfiguration.class).getBeanDefinition(); + getContribution(beanDefinition, singleConstructor(SimpleConfiguration.class)).applyTo(this.initialization); + assertThat(this.generatedTypeContext.toJavaFiles()).hasSize(1); + assertThat(CodeSnippet.of(this.initialization.toCodeBlock()).getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(SimpleConfiguration::new).register(beanFactory); + """); + } + + @Test + void generateUsingProtectedConstructorWritesToBlessedPackage() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(ProtectedConstructorComponent.class).getBeanDefinition(); + getContribution(beanDefinition, singleConstructor(ProtectedConstructorComponent.class)).applyTo(this.initialization); + assertThat(this.generatedTypeContext.hasGeneratedType(ProtectedConstructorComponent.class.getPackageName())).isTrue(); + GeneratedType generatedType = this.generatedTypeContext.getGeneratedType(ProtectedConstructorComponent.class.getPackageName()); + assertThat(removeIndent(codeOf(generatedType), 1)).containsSequence(""" + public static void registerTest(DefaultListableBeanFactory beanFactory) { + BeanDefinitionRegistrar.of("test", ProtectedConstructorComponent.class) + .instanceSupplier(ProtectedConstructorComponent::new).register(beanFactory); + }"""); + assertThat(CodeSnippet.of(this.initialization.toCodeBlock()).getSnippet()).isEqualTo( + ProtectedConstructorComponent.class.getPackageName() + ".Test.registerTest(beanFactory);\n"); + } + + @Test + void generateUsingProtectedFactoryMethodWritesToBlessedPackage() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(String.class).getBeanDefinition(); + getContribution(beanDefinition, method(ProtectedFactoryMethod.class, "testBean", Integer.class)) + .applyTo(this.initialization); + assertThat(this.generatedTypeContext.hasGeneratedType(ProtectedFactoryMethod.class.getPackageName())).isTrue(); + GeneratedType generatedType = this.generatedTypeContext.getGeneratedType(ProtectedConstructorComponent.class.getPackageName()); + assertThat(removeIndent(codeOf(generatedType), 1)).containsSequence(""" + public static void registerProtectedFactoryMethod_test(DefaultListableBeanFactory beanFactory) { + BeanDefinitionRegistrar.of("test", String.class).withFactoryMethod(ProtectedFactoryMethod.class, "testBean", Integer.class) + .instanceSupplier((instanceContext) -> instanceContext.create(beanFactory, (attributes) -> beanFactory.getBean(ProtectedFactoryMethod.class).testBean(attributes.get(0)))).register(beanFactory); + }"""); + assertThat(CodeSnippet.of(this.initialization.toCodeBlock()).getSnippet()).isEqualTo( + ProtectedConstructorComponent.class.getPackageName() + ".Test.registerProtectedFactoryMethod_test(beanFactory);\n"); + } + + @Test + void generateUsingProtectedGenericTypeWritesToBlessedPackage() { + RootBeanDefinition beanDefinition = (RootBeanDefinition) BeanDefinitionBuilder.rootBeanDefinition( + PublicFactoryBean.class).getBeanDefinition(); + beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, String.class); + // This resolve the generic parameter to a protected type + beanDefinition.setTargetType(PublicFactoryBean.resolveToProtectedGenericParameter()); + getContribution(beanDefinition, singleConstructor(PublicFactoryBean.class)).applyTo(this.initialization); + assertThat(this.generatedTypeContext.hasGeneratedType(PublicFactoryBean.class.getPackageName())).isTrue(); + GeneratedType generatedType = this.generatedTypeContext.getGeneratedType(PublicFactoryBean.class.getPackageName()); + assertThat(removeIndent(codeOf(generatedType), 1)).containsSequence(""" + public static void registerTest(DefaultListableBeanFactory beanFactory) { + BeanDefinitionRegistrar.of("test", ResolvableType.forClassWithGenerics(PublicFactoryBean.class, ProtectedType.class)).withConstructor(Class.class) + .instanceSupplier((instanceContext) -> instanceContext.create(beanFactory, (attributes) -> new PublicFactoryBean(attributes.get(0)))).customize((bd) -> bd.getConstructorArgumentValues().addIndexedArgumentValue(0, String.class)).register(beanFactory); + }"""); + assertThat(CodeSnippet.of(this.initialization.toCodeBlock()).getSnippet()).isEqualTo( + PublicFactoryBean.class.getPackageName() + ".Test.registerTest(beanFactory);\n"); + } + + @Test + void generateWithBeanDefinitionHavingSyntheticFlag() { + assertThat(simpleConfigurationRegistration(bd -> bd.setSynthetic(true)).getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> SimpleConfiguration::new).customize((bd) -> bd.setSynthetic(true)).register(beanFactory); + """); + } + + @Test + void generateWithBeanDefinitionHavingDependsOn() { + assertThat(simpleConfigurationRegistration(bd -> bd.setDependsOn("test")).getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> SimpleConfiguration::new).customize((bd) -> bd.setDependsOn(new String[] { "test" })).register(beanFactory); + """); + } + + @Test + void generateWithBeanDefinitionHavingLazyInit() { + assertThat(simpleConfigurationRegistration(bd -> bd.setLazyInit(true)).getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> SimpleConfiguration::new).customize((bd) -> bd.setLazyInit(true)).register(beanFactory); + """); + } + + @Test + void generateWithBeanDefinitionHavingRole() { + assertThat(simpleConfigurationRegistration(bd -> bd.setRole(BeanDefinition.ROLE_INFRASTRUCTURE)).getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> SimpleConfiguration::new).customize((bd) -> bd.setRole(2)).register(beanFactory); + """); + } + + @Test + void generateWithBeanDefinitionHavingScope() { + assertThat(simpleConfigurationRegistration(bd -> bd.setScope(ConfigurableBeanFactory.SCOPE_PROTOTYPE)).getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> SimpleConfiguration::new).customize((bd) -> bd.setScope("prototype")).register(beanFactory); + """); + } + + @Test + void generateWithBeanDefinitionHavingAutowiredCandidate() { + assertThat(simpleConfigurationRegistration(bd -> bd.setAutowireCandidate(false)).getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> SimpleConfiguration::new).customize((bd) -> bd.setAutowireCandidate(false)).register(beanFactory); + """); + } + + @Test + void generateWithBeanDefinitionHavingDefaultAutowiredCandidateDoesNotConfigureIt() { + assertThat(simpleConfigurationRegistration(bd -> bd.setAutowireCandidate(true)).getSnippet()) + .doesNotContain("bd.setAutowireCandidate("); + } + + @Test + void generateWithBeanDefinitionHavingMultipleAttributes() { + assertThat(simpleConfigurationRegistration(bd -> { + bd.setSynthetic(true); + bd.setPrimary(true); + }).getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> SimpleConfiguration::new).customize((bd) -> { + bd.setPrimary(true); + bd.setSynthetic(true); + }).register(beanFactory); + """); + } + + @Test + void generateWithBeanDefinitionHavingProperty() { + assertThat(simpleConfigurationRegistration(bd -> bd.getPropertyValues().addPropertyValue("test", "Hello")).getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> SimpleConfiguration::new).customize((bd) -> bd.getPropertyValues().addPropertyValue("test", "Hello")).register(beanFactory); + """); + } + + @Test + void generateWithBeanDefinitionHavingSeveralProperties() { + CodeSnippet registration = simpleConfigurationRegistration(bd -> { + bd.getPropertyValues().addPropertyValue("test", "Hello"); + bd.getPropertyValues().addPropertyValue("counter", 42); + }); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> SimpleConfiguration::new).customize((bd) -> { + MutablePropertyValues propertyValues = bd.getPropertyValues(); + propertyValues.addPropertyValue("test", "Hello"); + propertyValues.addPropertyValue("counter", 42); + }).register(beanFactory); + """); + assertThat(registration.hasImport(MutablePropertyValues.class)).isTrue(); + } + + @Test + void generateWithBeanDefinitionHavingPropertyReference() { + CodeSnippet registration = simpleConfigurationRegistration(bd -> bd.getPropertyValues() + .addPropertyValue("myService", new RuntimeBeanReference("test"))); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", SimpleConfiguration.class) + .instanceSupplier(() -> SimpleConfiguration::new).customize((bd) -> bd.getPropertyValues().addPropertyValue("myService", new RuntimeBeanReference("test"))).register(beanFactory); + """); + assertThat(registration.hasImport(RuntimeBeanReference.class)).isTrue(); + } + + @Test + void generateWithBeanDefinitionHavingPropertyAsBeanDefinition() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BeanDefinition innerBeanDefinition = BeanDefinitionBuilder.rootBeanDefinition(SimpleConfiguration.class, "stringBean") + .getBeanDefinition(); + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(ConfigurableBean.class) + .addPropertyValue("name", innerBeanDefinition).getBeanDefinition(); + getContribution(beanFactory, beanDefinition).applyTo(this.initialization); + CodeSnippet registration = CodeSnippet.of(this.initialization.toCodeBlock()); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", ConfigurableBean.class) + .instanceSupplier(ConfigurableBean::new).customize((bd) -> bd.getPropertyValues().addPropertyValue("name", BeanDefinitionRegistrar.inner(SimpleConfiguration.class).withFactoryMethod(SimpleConfiguration.class, "stringBean") + .instanceSupplier(() -> beanFactory.getBean(SimpleConfiguration.class).stringBean()).toBeanDefinition())).register(beanFactory); + """); + assertThat(registration.hasImport(SimpleConfiguration.class)).isTrue(); + } + + @Test + void generateWithBeanDefinitionHavingPropertyAsListOfBeanDefinitions() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BeanDefinition innerBeanDefinition = BeanDefinitionBuilder.rootBeanDefinition(SimpleConfiguration.class, "stringBean") + .getBeanDefinition(); + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(ConfigurableBean.class) + .addPropertyValue("names", List.of(innerBeanDefinition, innerBeanDefinition)).getBeanDefinition(); + getContribution(beanFactory, beanDefinition).applyTo(this.initialization); + CodeSnippet registration = CodeSnippet.of(this.initialization.toCodeBlock()); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", ConfigurableBean.class) + .instanceSupplier(ConfigurableBean::new).customize((bd) -> bd.getPropertyValues().addPropertyValue("names", List.of(BeanDefinitionRegistrar.inner(SimpleConfiguration.class).withFactoryMethod(SimpleConfiguration.class, "stringBean") + .instanceSupplier(() -> beanFactory.getBean(SimpleConfiguration.class).stringBean()).toBeanDefinition(), BeanDefinitionRegistrar.inner(SimpleConfiguration.class).withFactoryMethod(SimpleConfiguration.class, "stringBean") + .instanceSupplier(() -> beanFactory.getBean(SimpleConfiguration.class).stringBean()).toBeanDefinition()))).register(beanFactory); + """); + assertThat(registration.hasImport(SimpleConfiguration.class)).isTrue(); + } + + @Test + void generateWithBeanDefinitionHavingPropertyAsBeanDefinitionUseDedicatedVariableNames() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BeanDefinition innerBeanDefinition = BeanDefinitionBuilder.rootBeanDefinition(SimpleConfiguration.class, "stringBean") + .setRole(2).getBeanDefinition(); + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(ConfigurableBean.class) + .addPropertyValue("name", innerBeanDefinition).getBeanDefinition(); + getContribution(beanFactory, beanDefinition).applyTo(this.initialization); + CodeSnippet registration = CodeSnippet.of(this.initialization.toCodeBlock()); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", ConfigurableBean.class) + .instanceSupplier(ConfigurableBean::new).customize((bd) -> bd.getPropertyValues().addPropertyValue("name", BeanDefinitionRegistrar.inner(SimpleConfiguration.class).withFactoryMethod(SimpleConfiguration.class, "stringBean") + .instanceSupplier(() -> beanFactory.getBean(SimpleConfiguration.class).stringBean()).customize((bd_) -> bd_.setRole(2)).toBeanDefinition())).register(beanFactory); + """); + assertThat(registration.hasImport(SimpleConfiguration.class)).isTrue(); + } + + @Test + void generateUsingSingleConstructorArgument() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(String.class).getBeanDefinition(); + beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, "hello"); + CodeSnippet registration = beanRegistration(beanDefinition, method(SampleFactory.class, "create", String.class), + code -> code.add("() -> test")); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", String.class).withFactoryMethod(SampleFactory.class, "create", String.class) + .instanceSupplier(() -> test).customize((bd) -> bd.getConstructorArgumentValues().addIndexedArgumentValue(0, "hello")).register(beanFactory); + """); + } + + @Test + void generateUsingSeveralConstructorArguments() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(String.class) + .addConstructorArgValue(42).addConstructorArgReference("testBean") + .getBeanDefinition(); + CodeSnippet registration = beanRegistration(beanDefinition, method(SampleFactory.class, "create", Number.class, String.class), + code -> code.add("() -> test")); + assertThat(registration.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("test", String.class).withFactoryMethod(SampleFactory.class, "create", Number.class, String.class) + .instanceSupplier(() -> test).customize((bd) -> { + ConstructorArgumentValues argumentValues = bd.getConstructorArgumentValues(); + argumentValues.addIndexedArgumentValue(0, 42); + argumentValues.addIndexedArgumentValue(1, new RuntimeBeanReference("testBean")); + }).register(beanFactory); + """); + assertThat(registration.hasImport(ConstructorArgumentValues.class)).isTrue(); + } + + @Test + void registerRuntimeHintsWithNoPropertyValuesDoesNotAccessRuntimeHints() { + RootBeanDefinition bd = new RootBeanDefinition(String.class); + RuntimeHints runtimeHints = mock(RuntimeHints.class); + getContribution(new DefaultListableBeanFactory(), bd).registerRuntimeHints(runtimeHints); + verifyNoInteractions(runtimeHints); + } + + @Test + void registerRuntimeHintsWithInvalidProperty() { + BeanDefinition bd = BeanDefinitionBuilder.rootBeanDefinition(ConfigurableBean.class) + .addPropertyValue("notAProperty", "invalid").addPropertyValue("name", "hello") + .getBeanDefinition(); + RuntimeHints runtimeHints = new RuntimeHints(); + getContribution(new DefaultListableBeanFactory(), bd).registerRuntimeHints(runtimeHints); + assertThat(runtimeHints.reflection().getTypeHint(ConfigurableBean.class)).satisfies(hint -> { + assertThat(hint.fields()).isEmpty(); + assertThat(hint.constructors()).isEmpty(); + assertThat(hint.methods()).singleElement().satisfies(methodHint -> { + assertThat(methodHint.getName()).isEqualTo("setName"); + assertThat(methodHint.getParameterTypes()).containsExactly(TypeReference.of(String.class)); + assertThat(methodHint.getModes()).containsOnly(ExecutableMode.INVOKE); + }); + assertThat(hint.getMemberCategories()).isEmpty(); + }); + } + + @Test + void registerRuntimeHintsForPropertiesUseDeclaringClass() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("environment", mock(Environment.class)); + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(IntegerFactoryBean.class) + .addConstructorArgReference("environment") + .addPropertyValue("name", "Hello").getBeanDefinition(); + getContribution(beanFactory, beanDefinition).applyTo(this.initialization); + ReflectionHints reflectionHints = this.initialization.generatedTypeContext().runtimeHints().reflection(); + assertThat(reflectionHints.typeHints()).anySatisfy(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(BaseFactoryBean.class)); + assertThat(typeHint.constructors()).isEmpty(); + assertThat(typeHint.methods()).singleElement() + .satisfies(methodHint("setName", String.class)); + assertThat(typeHint.fields()).isEmpty(); + }).anySatisfy(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(IntegerFactoryBean.class)); + assertThat(typeHint.constructors()).singleElement() + .satisfies(constructorHint(Environment.class)); + assertThat(typeHint.methods()).isEmpty(); + assertThat(typeHint.fields()).isEmpty(); + }).hasSize(2); + } + + + @Test + void registerRuntimeHintsForProperties() { + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(NameAndCountersComponent.class) + .addPropertyValue("name", "Hello").addPropertyValue("counter", 42).getBeanDefinition(); + getContribution(new DefaultListableBeanFactory(), beanDefinition).applyTo(this.initialization); + ReflectionHints reflectionHints = this.initialization.generatedTypeContext().runtimeHints().reflection(); + assertThat(reflectionHints.typeHints()).singleElement().satisfies(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(NameAndCountersComponent.class)); + assertThat(typeHint.constructors()).isEmpty(); + assertThat(typeHint.methods()).anySatisfy(methodHint("setName", String.class)) + .anySatisfy(methodHint("setCounter", Integer.class)).hasSize(2); + assertThat(typeHint.fields()).isEmpty(); + }); + } + + + @Test + void registerReflectionEntriesForInnerBeanDefinition() { + AbstractBeanDefinition innerBd = BeanDefinitionBuilder.rootBeanDefinition(IntegerFactoryBean.class) + .addPropertyValue("name", "test").getBeanDefinition(); + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(NameAndCountersComponent.class) + .addPropertyValue("counter", innerBd).getBeanDefinition(); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("environment", Environment.class); + getContribution(beanFactory, beanDefinition).applyTo(this.initialization); + ReflectionHints reflectionHints = this.initialization.generatedTypeContext().runtimeHints().reflection(); + assertThat(reflectionHints.typeHints()).anySatisfy(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(NameAndCountersComponent.class)); + assertThat(typeHint.constructors()).isEmpty(); + assertThat(typeHint.methods()).singleElement().satisfies(methodHint("setCounter", Integer.class)); + assertThat(typeHint.fields()).isEmpty(); + }).anySatisfy(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(BaseFactoryBean.class)); + assertThat(typeHint.methods()).singleElement().satisfies(methodHint("setName", String.class)); + }).anySatisfy(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(IntegerFactoryBean.class)); + assertThat(typeHint.constructors()).singleElement().satisfies(constructorHint(Environment.class)); + }).hasSize(3); + } + + @Test + void registerReflectionEntriesForListOfInnerBeanDefinition() { + AbstractBeanDefinition innerBd1 = BeanDefinitionBuilder.rootBeanDefinition(IntegerFactoryBean.class) + .addPropertyValue("name", "test").getBeanDefinition(); + AbstractBeanDefinition innerBd2 = BeanDefinitionBuilder.rootBeanDefinition(AnotherIntegerFactoryBean.class) + .addPropertyValue("name", "test").getBeanDefinition(); + BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(NameAndCountersComponent.class) + .addPropertyValue("counters", List.of(innerBd1, innerBd2)).getBeanDefinition(); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("environment", Environment.class); + getContribution(beanFactory, beanDefinition).applyTo(this.initialization); + ReflectionHints reflectionHints = this.initialization.generatedTypeContext().runtimeHints().reflection(); + assertThat(reflectionHints.typeHints()).anySatisfy(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(NameAndCountersComponent.class)); + assertThat(typeHint.constructors()).isEmpty(); + assertThat(typeHint.methods()).singleElement().satisfies(methodHint("setCounters", List.class)); + assertThat(typeHint.fields()).isEmpty(); + }).anySatisfy(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(BaseFactoryBean.class)); + assertThat(typeHint.methods()).singleElement().satisfies(methodHint("setName", String.class)); + }).anySatisfy(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(IntegerFactoryBean.class)); + assertThat(typeHint.constructors()).singleElement().satisfies(constructorHint(Environment.class)); + }).anySatisfy(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(AnotherIntegerFactoryBean.class)); + assertThat(typeHint.constructors()).singleElement().satisfies(constructorHint(Environment.class)); + }).hasSize(4); + } + + private Consumer methodHint(String name, Class... parameterTypes) { + return executableHint -> { + assertThat(executableHint.getName()).isEqualTo(name); + assertThat(executableHint.getParameterTypes()).containsExactly(Arrays.stream(parameterTypes) + .map(TypeReference::of).toArray(TypeReference[]::new)); + }; + } + + private Consumer constructorHint(Class... parameterTypes) { + return methodHint("", parameterTypes); + } + + + private CodeSnippet simpleConfigurationRegistration(Consumer bd) { + RootBeanDefinition beanDefinition = (RootBeanDefinition) BeanDefinitionBuilder + .rootBeanDefinition(SimpleConfiguration.class).getBeanDefinition(); + bd.accept(beanDefinition); + return beanRegistration(beanDefinition, singleConstructor(SimpleConfiguration.class), + code -> code.add("() -> SimpleConfiguration::new")); + } + + private BeanRegistrationBeanFactoryContribution getContribution(DefaultListableBeanFactory beanFactory, BeanDefinition beanDefinition) { + BeanRegistrationBeanFactoryContribution contribution = new DefaultBeanRegistrationContributionProvider(beanFactory) + .getContributionFor("test", (RootBeanDefinition) beanDefinition); + assertThat(contribution).isNotNull(); + return contribution; + } + + private BeanFactoryContribution getContribution(BeanDefinition beanDefinition, Executable instanceCreator) { + return new BeanRegistrationBeanFactoryContribution("test", beanDefinition, + new DefaultBeanInstantiationGenerator(instanceCreator, Collections.emptyList())); + } + + private CodeSnippet beanRegistration(BeanDefinition beanDefinition, Executable instanceCreator, Consumer instanceSupplier) { + BeanRegistrationBeanFactoryContribution generator = new BeanRegistrationBeanFactoryContribution("test", beanDefinition, + new DefaultBeanInstantiationGenerator(instanceCreator, Collections.emptyList())); + return CodeSnippet.of(generator.generateBeanRegistration(new RuntimeHints(), + toMultiStatements(instanceSupplier))); + } + + private Constructor singleConstructor(Class type) { + return type.getDeclaredConstructors()[0]; + } + + private Method method(Class type, String name, Class... parameterTypes) { + Method method = ReflectionUtils.findMethod(type, name, parameterTypes); + assertThat(method).isNotNull(); + return method; + } + + private MultiStatement toMultiStatements(Consumer instanceSupplier) { + Builder code = CodeBlock.builder(); + instanceSupplier.accept(code); + MultiStatement statements = new MultiStatement(); + statements.add(code.build()); + return statements; + } + + private String codeOf(GeneratedType type) { + try { + StringWriter out = new StringWriter(); + type.toJavaFile().writeTo(out); + return out.toString(); + } + catch (IOException ex) { + throw new IllegalStateException(ex); + } + } + + private String removeIndent(String content, int indent) { + return content.lines().map(line -> { + for (int i = 0; i < indent; i++) { + if (line.startsWith("\t")) { + line = line.substring(1); + } + } + return line; + }).collect(Collectors.joining("\n")); + } + + static abstract class BaseFactoryBean { + + public void setName(String name) { + + } + + } + + @SuppressWarnings("unused") + static class IntegerFactoryBean extends BaseFactoryBean implements FactoryBean { + + public IntegerFactoryBean(Environment environment) { + + } + + @Override + public Class getObjectType() { + return Integer.class; + } + + @Override + public Integer getObject() { + return 42; + } + + } + + @SuppressWarnings("unused") + static class AnotherIntegerFactoryBean extends IntegerFactoryBean { + + public AnotherIntegerFactoryBean(Environment environment) { + super(environment); + } + + } + + static class NameAndCountersComponent { + + private String name; + + private List counters; + + public void setName(String name) { + this.name = name; + } + + public void setCounter(Integer counter) { + setCounters(List.of(counter)); + } + + public void setCounters(List counters) { + this.counters = counters; + } + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/generator/DefaultBeanRegistrationContributionProviderTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/generator/DefaultBeanRegistrationContributionProviderTests.java new file mode 100644 index 00000000000..412deeb47bb --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/generator/DefaultBeanRegistrationContributionProviderTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration; +import org.springframework.core.Ordered; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests for {@link DefaultBeanRegistrationContributionProvider}. + * + * @author Stephane Nicoll + */ +class DefaultBeanRegistrationContributionProviderTests { + + @Test + void aotContributingBeanPostProcessorsAreIncluded() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + AotContributingBeanPostProcessor first = mockNoOpPostProcessor(-1); + AotContributingBeanPostProcessor second = mockNoOpPostProcessor(5); + beanFactory.registerBeanDefinition("second", BeanDefinitionBuilder.rootBeanDefinition( + AotContributingBeanPostProcessor.class, () -> second).getBeanDefinition()); + beanFactory.registerBeanDefinition("first", BeanDefinitionBuilder.rootBeanDefinition( + AotContributingBeanPostProcessor.class, () -> first).getBeanDefinition()); + RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleConfiguration.class); + new DefaultBeanRegistrationContributionProvider(beanFactory).getContributionFor( + "test", beanDefinition); + verify((Ordered) second).getOrder(); + verify((Ordered) first).getOrder(); + verify(first).contribute(beanDefinition, SimpleConfiguration.class, "test"); + verify(second).contribute(beanDefinition, SimpleConfiguration.class, "test"); + verifyNoMoreInteractions(first, second); + } + + + private AotContributingBeanPostProcessor mockNoOpPostProcessor(int order) { + AotContributingBeanPostProcessor postProcessor = mock(AotContributingBeanPostProcessor.class); + given(postProcessor.contribute(any(), any(), any())).willReturn(null); + given(postProcessor.getOrder()).willReturn(order); + return postProcessor; + } + +}