Merge AOT constructor and factory method resolution into ConstructorResolver

This moves related code into the same class, unifies candidate determination for constructors and factory methods, and gets rid of the package cycle around the hard-coded Autowired annotation check (which is implicitly coming from AutowiredAnnotationBeanPostProcessor via the determineCandidateConstructors SPI now). The API entry point for AOT pre-resolution purposes is in RegisteredBean.

Closes gh-27920
This commit is contained in:
Juergen Hoeller 2022-10-06 11:59:11 +02:00
parent 3af0c232b7
commit aedef9321a
7 changed files with 377 additions and 468 deletions

View File

@ -76,12 +76,12 @@ class BeanDefinitionMethodGenerator {
this.methodGeneratorFactory = methodGeneratorFactory;
this.registeredBean = registeredBean;
this.constructorOrFactoryMethod = ConstructorOrFactoryMethodResolver
.resolve(registeredBean);
this.constructorOrFactoryMethod = registeredBean.resolveConstructorOrFactoryMethod();
this.innerBeanPropertyName = innerBeanPropertyName;
this.aotContributions = aotContributions;
}
/**
* Generate the method that returns the {@link BeanDefinition} to be
* registered.
@ -95,22 +95,17 @@ class BeanDefinitionMethodGenerator {
registerRuntimeHintsIfNecessary(generationContext.getRuntimeHints());
BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext,
beanRegistrationsCode);
ClassName target = codeFragments.getTarget(this.registeredBean,
this.constructorOrFactoryMethod);
ClassName target = codeFragments.getTarget(this.registeredBean, this.constructorOrFactoryMethod);
if (!target.canonicalName().startsWith("java.")) {
GeneratedClass generatedClass = lookupGeneratedClass(generationContext, target);
GeneratedMethods generatedMethods = generatedClass.getMethods()
.withPrefix(getName());
GeneratedMethod generatedMethod = generateBeanDefinitionMethod(
generationContext, generatedClass.getName(), generatedMethods,
codeFragments, Modifier.PUBLIC);
GeneratedMethods generatedMethods = generatedClass.getMethods().withPrefix(getName());
GeneratedMethod generatedMethod = generateBeanDefinitionMethod(generationContext,
generatedClass.getName(), generatedMethods, codeFragments, Modifier.PUBLIC);
return generatedMethod.toMethodReference();
}
GeneratedMethods generatedMethods = beanRegistrationsCode.getMethods()
.withPrefix(getName());
GeneratedMethods generatedMethods = beanRegistrationsCode.getMethods().withPrefix(getName());
GeneratedMethod generatedMethod = generateBeanDefinitionMethod(generationContext,
beanRegistrationsCode.getClassName(), generatedMethods, codeFragments,
Modifier.PRIVATE);
beanRegistrationsCode.getClassName(), generatedMethods, codeFragments, Modifier.PRIVATE);
return generatedMethod.toMethodReference();
}
@ -219,6 +214,7 @@ class BeanDefinitionMethodGenerator {
}
}
private static class ProxyRuntimeHintsRegistrar {
private final AutowireCandidateResolver candidateResolver;
@ -231,8 +227,7 @@ class BeanDefinitionMethodGenerator {
Class<?>[] parameterTypes = method.getParameterTypes();
for (int i = 0; i < parameterTypes.length; i++) {
MethodParameter methodParam = new MethodParameter(method, i);
DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(
methodParam, true);
DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(methodParam, true);
registerProxyIfNecessary(runtimeHints, dependencyDescriptor);
}
}
@ -248,13 +243,11 @@ class BeanDefinitionMethodGenerator {
}
private void registerProxyIfNecessary(RuntimeHints runtimeHints, DependencyDescriptor dependencyDescriptor) {
Class<?> proxyType = this.candidateResolver
.getLazyResolutionProxyClass(dependencyDescriptor, null);
Class<?> proxyType = this.candidateResolver.getLazyResolutionProxyClass(dependencyDescriptor, null);
if (proxyType != null && Proxy.isProxyClass(proxyType)) {
runtimeHints.proxies().registerJdkProxy(proxyType.getInterfaces());
}
}
}
}

View File

@ -1,407 +0,0 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.beans.factory.aot;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanReference;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.beans.factory.config.ConstructorArgumentValues;
import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.AbstractBeanDefinition;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
/**
* Resolves the {@link Executable} (factory method or constructor) that should
* be used to create a bean. This class is similar to
* {@code org.springframework.beans.factory.support.ConstructorResolver} but it
* doesn't need bean initialization.
*
* @author Stephane Nicoll
* @author Phillip Webb
* @since 6.0
*/
class ConstructorOrFactoryMethodResolver {
private final ConfigurableBeanFactory beanFactory;
@Nullable
private final ClassLoader classLoader;
ConstructorOrFactoryMethodResolver(ConfigurableBeanFactory beanFactory) {
this.beanFactory = beanFactory;
this.classLoader = (beanFactory.getBeanClassLoader() != null ?
beanFactory.getBeanClassLoader() : ClassUtils.getDefaultClassLoader());
}
Executable resolve(BeanDefinition beanDefinition) {
Supplier<ResolvableType> beanType = () -> getBeanType(beanDefinition);
List<ResolvableType> valueTypes = (beanDefinition.hasConstructorArgumentValues() ?
determineParameterValueTypes(beanDefinition.getConstructorArgumentValues()) :
Collections.emptyList());
Method resolvedFactoryMethod = resolveFactoryMethod(beanDefinition, valueTypes);
if (resolvedFactoryMethod != null) {
return resolvedFactoryMethod;
}
Class<?> factoryBeanClass = getFactoryBeanClass(beanDefinition);
if (factoryBeanClass != null && !factoryBeanClass.equals(beanDefinition.getResolvableType().toClass())) {
ResolvableType resolvableType = beanDefinition.getResolvableType();
boolean isCompatible = ResolvableType.forClass(factoryBeanClass)
.as(FactoryBean.class).getGeneric(0).isAssignableFrom(resolvableType);
Assert.state(isCompatible, () -> String.format(
"Incompatible target type '%s' for factory bean '%s'",
resolvableType.toClass().getName(), factoryBeanClass.getName()));
Executable executable = resolveConstructor(() -> ResolvableType.forClass(factoryBeanClass), valueTypes);
if (executable != null) {
return executable;
}
throw new IllegalStateException("No suitable FactoryBean constructor found for "
+ beanDefinition + " and argument types " + valueTypes);
}
Executable resolvedConstructor = resolveConstructor(beanType, valueTypes);
if (resolvedConstructor != null) {
return resolvedConstructor;
}
throw new IllegalStateException("No constructor or factory method candidate found for "
+ beanDefinition + " and argument types " + valueTypes);
}
private List<ResolvableType> determineParameterValueTypes(
ConstructorArgumentValues constructorArgumentValues) {
List<ResolvableType> parameterTypes = new ArrayList<>();
for (ValueHolder valueHolder : constructorArgumentValues
.getIndexedArgumentValues().values()) {
parameterTypes.add(determineParameterValueType(valueHolder));
}
return parameterTypes;
}
private ResolvableType determineParameterValueType(ValueHolder valueHolder) {
if (valueHolder.getType() != null) {
return ResolvableType.forClass(loadClass(valueHolder.getType()));
}
Object value = valueHolder.getValue();
if (value instanceof BeanReference br) {
if (value instanceof RuntimeBeanReference rbr) {
if (rbr.getBeanType() != null) {
return ResolvableType.forClass(rbr.getBeanType());
}
}
return ResolvableType.forClass(this.beanFactory.getType(br.getBeanName(), false));
}
if (value instanceof BeanDefinition bd) {
return extractTypeFromBeanDefinition(getBeanType(bd));
}
if (value instanceof Class<?> clazz) {
return ResolvableType.forClassWithGenerics(Class.class, clazz);
}
return ResolvableType.forInstance(value);
}
private ResolvableType extractTypeFromBeanDefinition(ResolvableType type) {
if (FactoryBean.class.isAssignableFrom(type.toClass())) {
return type.as(FactoryBean.class).getGeneric(0);
}
return type;
}
@Nullable
private Method resolveFactoryMethod(BeanDefinition beanDefinition, List<ResolvableType> valueTypes) {
if (beanDefinition instanceof RootBeanDefinition rbd) {
Method resolvedFactoryMethod = rbd.getResolvedFactoryMethod();
if (resolvedFactoryMethod != null) {
return resolvedFactoryMethod;
}
}
String factoryMethodName = beanDefinition.getFactoryMethodName();
if (factoryMethodName != null) {
String factoryBeanName = beanDefinition.getFactoryBeanName();
Class<?> beanClass = getBeanClass(factoryBeanName != null ?
this.beanFactory.getMergedBeanDefinition(factoryBeanName) : beanDefinition);
List<Method> methods = new ArrayList<>();
Assert.state(beanClass != null,
() -> "Failed to determine bean class of " + beanDefinition);
ReflectionUtils.doWithMethods(beanClass, methods::add,
method -> isFactoryMethodCandidate(beanClass, method, factoryMethodName));
if (methods.size() >= 1) {
Function<Method, List<ResolvableType>> parameterTypesFactory = method -> {
List<ResolvableType> types = new ArrayList<>();
for (int i = 0; i < method.getParameterCount(); i++) {
types.add(ResolvableType.forMethodParameter(method, i));
}
return types;
};
return (Method) resolveFactoryMethod(methods, parameterTypesFactory, valueTypes);
}
}
return null;
}
private boolean isFactoryMethodCandidate(Class<?> beanClass, Method method, String factoryMethodName) {
if (method.getName().equals(factoryMethodName)) {
if (Modifier.isStatic(method.getModifiers())) {
return method.getDeclaringClass().equals(beanClass);
}
return !Modifier.isPrivate(method.getModifiers());
}
return false;
}
@Nullable
private Executable resolveConstructor(Supplier<ResolvableType> beanType, List<ResolvableType> valueTypes) {
Class<?> type = ClassUtils.getUserClass(beanType.get().toClass());
Constructor<?>[] constructors = type.getDeclaredConstructors();
if (constructors.length == 1) {
return constructors[0];
}
for (Constructor<?> constructor : constructors) {
if (MergedAnnotations.from(constructor).isPresent(Autowired.class)) {
return constructor;
}
}
Function<Constructor<?>, List<ResolvableType>> parameterTypesFactory = executable -> {
List<ResolvableType> types = new ArrayList<>();
for (int i = 0; i < executable.getParameterCount(); i++) {
types.add(ResolvableType.forConstructorParameter(executable, i));
}
return types;
};
List<? extends Executable> matches = Arrays.stream(constructors)
.filter(executable -> match(parameterTypesFactory.apply(executable),
valueTypes, FallbackMode.NONE))
.toList();
if (matches.size() == 1) {
return matches.get(0);
}
List<? extends Executable> assignableElementFallbackMatches = Arrays
.stream(constructors)
.filter(executable -> match(parameterTypesFactory.apply(executable),
valueTypes, FallbackMode.ASSIGNABLE_ELEMENT))
.toList();
if (assignableElementFallbackMatches.size() == 1) {
return assignableElementFallbackMatches.get(0);
}
List<? extends Executable> typeConversionFallbackMatches = Arrays
.stream(constructors)
.filter(executable -> match(parameterTypesFactory.apply(executable),
valueTypes, FallbackMode.TYPE_CONVERSION))
.toList();
return (typeConversionFallbackMatches.size() == 1)
? typeConversionFallbackMatches.get(0) : null;
}
@Nullable
private Executable resolveFactoryMethod(List<Method> executables,
Function<Method, List<ResolvableType>> parameterTypesFactory,
List<ResolvableType> valueTypes) {
List<? extends Executable> matches = executables.stream()
.filter(executable -> match(parameterTypesFactory.apply(executable), valueTypes, FallbackMode.NONE))
.toList();
if (matches.size() == 1) {
return matches.get(0);
}
List<? extends Executable> assignableElementFallbackMatches = executables.stream()
.filter(executable -> match(parameterTypesFactory.apply(executable),
valueTypes, FallbackMode.ASSIGNABLE_ELEMENT))
.toList();
if (assignableElementFallbackMatches.size() == 1) {
return assignableElementFallbackMatches.get(0);
}
List<? extends Executable> typeConversionFallbackMatches = executables.stream()
.filter(executable -> match(parameterTypesFactory.apply(executable),
valueTypes, FallbackMode.TYPE_CONVERSION))
.toList();
Assert.state(typeConversionFallbackMatches.size() <= 1,
() -> "Multiple matches with parameters '" + valueTypes + "': " + typeConversionFallbackMatches);
return (typeConversionFallbackMatches.size() == 1 ? typeConversionFallbackMatches.get(0) : null);
}
private boolean match(
List<ResolvableType> parameterTypes, List<ResolvableType> valueTypes, FallbackMode fallbackMode) {
if (parameterTypes.size() != valueTypes.size()) {
return false;
}
for (int i = 0; i < parameterTypes.size(); i++) {
if (!isMatch(parameterTypes.get(i), valueTypes.get(i), fallbackMode)) {
return false;
}
}
return true;
}
private boolean isMatch(ResolvableType parameterType, ResolvableType valueType, FallbackMode fallbackMode) {
if (isAssignable(valueType).test(parameterType)) {
return true;
}
return switch (fallbackMode) {
case ASSIGNABLE_ELEMENT ->
isAssignable(valueType).test(extractElementType(parameterType));
case TYPE_CONVERSION -> typeConversionFallback(valueType).test(parameterType);
default -> false;
};
}
private Predicate<ResolvableType> isAssignable(ResolvableType valueType) {
return parameterType -> parameterType.isAssignableFrom(valueType);
}
private ResolvableType extractElementType(ResolvableType parameterType) {
if (parameterType.isArray()) {
return parameterType.getComponentType();
}
if (Collection.class.isAssignableFrom(parameterType.toClass())) {
return parameterType.as(Collection.class).getGeneric(0);
}
return ResolvableType.NONE;
}
private Predicate<ResolvableType> typeConversionFallback(ResolvableType valueType) {
return parameterType -> {
if (valueOrCollection(valueType, this::isStringForClassFallback).test(parameterType)) {
return true;
}
return valueOrCollection(valueType, this::isSimpleValueType).test(parameterType);
};
}
private Predicate<ResolvableType> valueOrCollection(ResolvableType valueType,
Function<ResolvableType, Predicate<ResolvableType>> predicateProvider) {
return parameterType -> {
if (predicateProvider.apply(valueType).test(parameterType)) {
return true;
}
if (predicateProvider.apply(extractElementType(valueType)).test(extractElementType(parameterType))) {
return true;
}
return (predicateProvider.apply(valueType).test(extractElementType(parameterType)));
};
}
/**
* Return a {@link Predicate} for a parameter type that checks if its target
* value is a {@link Class} and the value type is a {@link String}. This is
* a regular use cases where a {@link Class} is defined in the bean
* definition as an FQN.
* @param valueType the type of the value
* @return a predicate to indicate a fallback match for a String to Class
* parameter
*/
private Predicate<ResolvableType> isStringForClassFallback(ResolvableType valueType) {
return parameterType -> (valueType.isAssignableFrom(String.class) &&
parameterType.isAssignableFrom(Class.class));
}
private Predicate<ResolvableType> isSimpleValueType(ResolvableType valueType) {
return parameterType -> (BeanUtils.isSimpleValueType(parameterType.toClass()) &&
BeanUtils.isSimpleValueType(valueType.toClass()));
}
@Nullable
private Class<?> getFactoryBeanClass(BeanDefinition beanDefinition) {
if (beanDefinition instanceof RootBeanDefinition rbd) {
if (rbd.hasBeanClass()) {
Class<?> beanClass = rbd.getBeanClass();
return (FactoryBean.class.isAssignableFrom(beanClass) ? beanClass : null);
}
}
return null;
}
@Nullable
private Class<?> getBeanClass(BeanDefinition beanDefinition) {
if (beanDefinition instanceof AbstractBeanDefinition abd && abd.hasBeanClass()) {
return abd.getBeanClass();
}
return (beanDefinition.getBeanClassName() != null ? loadClass(beanDefinition.getBeanClassName()) : null);
}
private ResolvableType getBeanType(BeanDefinition beanDefinition) {
ResolvableType resolvableType = beanDefinition.getResolvableType();
if (resolvableType != ResolvableType.NONE) {
return resolvableType;
}
if (beanDefinition instanceof RootBeanDefinition rbd) {
if (rbd.hasBeanClass()) {
return ResolvableType.forClass(rbd.getBeanClass());
}
}
String beanClassName = beanDefinition.getBeanClassName();
if (beanClassName != null) {
return ResolvableType.forClass(loadClass(beanClassName));
}
throw new IllegalStateException(
"Failed to determine bean class of " + beanDefinition);
}
private Class<?> loadClass(String beanClassName) {
try {
return ClassUtils.forName(beanClassName, this.classLoader);
}
catch (ClassNotFoundException ex) {
throw new IllegalStateException("Failed to load class " + beanClassName);
}
}
static Executable resolve(RegisteredBean registeredBean) {
return new ConstructorOrFactoryMethodResolver(registeredBean.getBeanFactory())
.resolve(registeredBean.getMergedBeanDefinition());
}
enum FallbackMode {
NONE,
ASSIGNABLE_ELEMENT,
TYPE_CONVERSION
}
}

View File

@ -25,6 +25,7 @@ import java.lang.reflect.Modifier;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
@ -32,10 +33,14 @@ import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.apache.commons.logging.Log;
import org.springframework.beans.BeanMetadataElement;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.BeanWrapper;
import org.springframework.beans.BeanWrapperImpl;
import org.springframework.beans.BeansException;
@ -43,18 +48,23 @@ import org.springframework.beans.TypeConverter;
import org.springframework.beans.TypeMismatchException;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.BeanDefinitionStoreException;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.InjectionPoint;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.UnsatisfiedDependencyException;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanReference;
import org.springframework.beans.factory.config.ConstructorArgumentValues;
import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder;
import org.springframework.beans.factory.config.DependencyDescriptor;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.core.CollectionFactory;
import org.springframework.core.MethodParameter;
import org.springframework.core.NamedThreadLocal;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.core.ResolvableType;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
@ -74,9 +84,12 @@ import org.springframework.util.StringUtils;
* @author Costin Leau
* @author Sebastien Deleuze
* @author Sam Brannen
* @author Stephane Nicoll
* @author Phil Webb
* @since 2.0
* @see #autowireConstructor
* @see #instantiateUsingFactoryMethod
* @see #resolveConstructorOrFactoryMethod
* @see AbstractAutowireCapableBeanFactory
*/
class ConstructorResolver {
@ -108,6 +121,8 @@ class ConstructorResolver {
}
// BeanWrapper-based construction
/**
* "autowire constructor" (with constructor arguments by type) behavior.
* Also applied if explicit constructor argument values are specified,
@ -329,7 +344,7 @@ class ConstructorResolver {
Method[] candidates = getCandidateMethods(factoryClass, mbd);
Method uniqueCandidate = null;
for (Method candidate : candidates) {
if (Modifier.isStatic(candidate.getModifiers()) == isStatic && mbd.isFactoryMethod(candidate)) {
if ((!isStatic || isStaticCandidate(candidate, factoryClass)) && mbd.isFactoryMethod(candidate)) {
if (uniqueCandidate == null) {
uniqueCandidate = candidate;
}
@ -359,6 +374,10 @@ class ConstructorResolver {
ReflectionUtils.getAllDeclaredMethods(factoryClass) : factoryClass.getMethods());
}
private boolean isStaticCandidate(Method method, Class<?> factoryClass) {
return (Modifier.isStatic(method.getModifiers()) && method.getDeclaringClass() == factoryClass);
}
/**
* Instantiate the bean using a named factory method. The method may be static, if the
* bean definition parameter specifies a class, rather than a "factory-bean", or
@ -451,7 +470,7 @@ class ConstructorResolver {
candidates = new ArrayList<>();
Method[] rawCandidates = getCandidateMethods(factoryClass, mbd);
for (Method candidate : rawCandidates) {
if (Modifier.isStatic(candidate.getModifiers()) == isStatic && mbd.isFactoryMethod(candidate)) {
if ((!isStatic || isStaticCandidate(candidate, factoryClass)) && mbd.isFactoryMethod(candidate)) {
candidates.add(candidate);
}
}
@ -593,7 +612,7 @@ class ConstructorResolver {
throw new BeanCreationException(mbd.getResourceDescription(), beanName,
"No matching factory method found on class [" + factoryClass.getName() + "]: " +
(mbd.getFactoryBeanName() != null ?
"factory bean '" + mbd.getFactoryBeanName() + "'; " : "") +
"factory bean '" + mbd.getFactoryBeanName() + "'; " : "") +
"factory method '" + mbd.getFactoryMethodName() + "(" + argDesc + ")'. " +
"Check that a method with the specified name " +
(minNrOfArgs > 0 ? "and arguments " : "") +
@ -882,6 +901,311 @@ class ConstructorResolver {
}
}
// AOT-oriented pre-resolution
public Executable resolveConstructorOrFactoryMethod(String beanName, RootBeanDefinition mbd) {
Supplier<ResolvableType> beanType = () -> getBeanType(beanName, mbd);
List<ResolvableType> valueTypes = (mbd.hasConstructorArgumentValues() ?
determineParameterValueTypes(mbd) : Collections.emptyList());
Method resolvedFactoryMethod = resolveFactoryMethod(beanName, mbd, valueTypes);
if (resolvedFactoryMethod != null) {
return resolvedFactoryMethod;
}
Class<?> factoryBeanClass = getFactoryBeanClass(beanName, mbd);
if (factoryBeanClass != null && !factoryBeanClass.equals(mbd.getResolvableType().toClass())) {
ResolvableType resolvableType = mbd.getResolvableType();
boolean isCompatible = ResolvableType.forClass(factoryBeanClass)
.as(FactoryBean.class).getGeneric(0).isAssignableFrom(resolvableType);
Assert.state(isCompatible, () -> String.format(
"Incompatible target type '%s' for factory bean '%s'",
resolvableType.toClass().getName(), factoryBeanClass.getName()));
Executable executable = resolveConstructor(beanName, mbd,
() -> ResolvableType.forClass(factoryBeanClass), valueTypes);
if (executable != null) {
return executable;
}
throw new IllegalStateException("No suitable FactoryBean constructor found for " +
mbd + " and argument types " + valueTypes);
}
Executable resolvedConstructor = resolveConstructor(beanName, mbd, beanType, valueTypes);
if (resolvedConstructor != null) {
return resolvedConstructor;
}
throw new IllegalStateException("No constructor or factory method candidate found for " +
mbd + " and argument types " + valueTypes);
}
private List<ResolvableType> determineParameterValueTypes(RootBeanDefinition mbd) {
List<ResolvableType> parameterTypes = new ArrayList<>();
for (ValueHolder valueHolder : mbd.getConstructorArgumentValues().getIndexedArgumentValues().values()) {
parameterTypes.add(determineParameterValueType(mbd, valueHolder));
}
return parameterTypes;
}
private ResolvableType determineParameterValueType(RootBeanDefinition mbd, ValueHolder valueHolder) {
if (valueHolder.getType() != null) {
return ResolvableType.forClass(
ClassUtils.resolveClassName(valueHolder.getType(), this.beanFactory.getBeanClassLoader()));
}
Object value = valueHolder.getValue();
if (value instanceof BeanReference br) {
if (value instanceof RuntimeBeanReference rbr) {
if (rbr.getBeanType() != null) {
return ResolvableType.forClass(rbr.getBeanType());
}
}
return ResolvableType.forClass(this.beanFactory.getType(br.getBeanName(), false));
}
if (value instanceof BeanDefinition innerBd) {
String nameToUse = "(inner bean)";
ResolvableType type = getBeanType(nameToUse,
this.beanFactory.getMergedBeanDefinition(nameToUse, innerBd, mbd));
return (FactoryBean.class.isAssignableFrom(type.toClass()) ?
type.as(FactoryBean.class).getGeneric(0) : type);
}
if (value instanceof Class<?> clazz) {
return ResolvableType.forClassWithGenerics(Class.class, clazz);
}
return ResolvableType.forInstance(value);
}
@Nullable
private Executable resolveConstructor(String beanName, RootBeanDefinition mbd,
Supplier<ResolvableType> beanType, List<ResolvableType> valueTypes) {
Class<?> type = ClassUtils.getUserClass(beanType.get().toClass());
Constructor<?>[] ctors = this.beanFactory.determineConstructorsFromBeanPostProcessors(type, beanName);
if (ctors == null) {
if (!mbd.hasConstructorArgumentValues()) {
ctors = mbd.getPreferredConstructors();
}
if (ctors == null) {
ctors = (mbd.isNonPublicAccessAllowed() ? type.getDeclaredConstructors() : type.getConstructors());
}
}
if (ctors.length == 1) {
return ctors[0];
}
Function<Constructor<?>, List<ResolvableType>> parameterTypesFactory = executable -> {
List<ResolvableType> types = new ArrayList<>();
for (int i = 0; i < executable.getParameterCount(); i++) {
types.add(ResolvableType.forConstructorParameter(executable, i));
}
return types;
};
List<? extends Executable> matches = Arrays.stream(ctors)
.filter(executable -> match(parameterTypesFactory.apply(executable),
valueTypes, FallbackMode.NONE))
.toList();
if (matches.size() == 1) {
return matches.get(0);
}
List<? extends Executable> assignableElementFallbackMatches = Arrays
.stream(ctors)
.filter(executable -> match(parameterTypesFactory.apply(executable),
valueTypes, FallbackMode.ASSIGNABLE_ELEMENT))
.toList();
if (assignableElementFallbackMatches.size() == 1) {
return assignableElementFallbackMatches.get(0);
}
List<? extends Executable> typeConversionFallbackMatches = Arrays
.stream(ctors)
.filter(executable -> match(parameterTypesFactory.apply(executable),
valueTypes, FallbackMode.TYPE_CONVERSION))
.toList();
return (typeConversionFallbackMatches.size() == 1 ? typeConversionFallbackMatches.get(0) : null);
}
@Nullable
private Method resolveFactoryMethod(String beanName, RootBeanDefinition mbd, List<ResolvableType> valueTypes) {
if (mbd.isFactoryMethodUnique) {
Method resolvedFactoryMethod = mbd.getResolvedFactoryMethod();
if (resolvedFactoryMethod != null) {
return resolvedFactoryMethod;
}
}
String factoryMethodName = mbd.getFactoryMethodName();
if (factoryMethodName != null) {
String factoryBeanName = mbd.getFactoryBeanName();
Class<?> factoryClass;
boolean isStatic;
if (factoryBeanName != null) {
factoryClass = this.beanFactory.getType(factoryBeanName);
isStatic = false;
}
else {
factoryClass = this.beanFactory.resolveBeanClass(mbd, beanName);
isStatic = true;
}
Assert.state(factoryClass != null, () -> "Failed to determine bean class of " + mbd);
Method[] rawCandidates = getCandidateMethods(factoryClass, mbd);
List<Method> candidates = new ArrayList<>();
for (Method candidate : rawCandidates) {
if ((!isStatic || isStaticCandidate(candidate, factoryClass)) && mbd.isFactoryMethod(candidate)) {
candidates.add(candidate);
}
}
Method result = null;
if (candidates.size() == 1) {
result = candidates.get(0);
}
else if (candidates.size() > 1) {
Function<Method, List<ResolvableType>> parameterTypesFactory = method -> {
List<ResolvableType> types = new ArrayList<>();
for (int i = 0; i < method.getParameterCount(); i++) {
types.add(ResolvableType.forMethodParameter(method, i));
}
return types;
};
result = (Method) resolveFactoryMethod(candidates, parameterTypesFactory, valueTypes);
}
if (result == null) {
throw new BeanCreationException(mbd.getResourceDescription(), beanName,
"No matching factory method found on class [" + factoryClass.getName() + "]: " +
(mbd.getFactoryBeanName() != null ?
"factory bean '" + mbd.getFactoryBeanName() + "'; " : "") +
"factory method '" + mbd.getFactoryMethodName() + "'. ");
}
return result;
}
return null;
}
@Nullable
private Executable resolveFactoryMethod(List<Method> executables,
Function<Method, List<ResolvableType>> parameterTypesFactory,
List<ResolvableType> valueTypes) {
List<? extends Executable> matches = executables.stream()
.filter(executable -> match(parameterTypesFactory.apply(executable), valueTypes, FallbackMode.NONE))
.toList();
if (matches.size() == 1) {
return matches.get(0);
}
List<? extends Executable> assignableElementFallbackMatches = executables.stream()
.filter(executable -> match(parameterTypesFactory.apply(executable),
valueTypes, FallbackMode.ASSIGNABLE_ELEMENT))
.toList();
if (assignableElementFallbackMatches.size() == 1) {
return assignableElementFallbackMatches.get(0);
}
List<? extends Executable> typeConversionFallbackMatches = executables.stream()
.filter(executable -> match(parameterTypesFactory.apply(executable),
valueTypes, FallbackMode.TYPE_CONVERSION))
.toList();
Assert.state(typeConversionFallbackMatches.size() <= 1,
() -> "Multiple matches with parameters '" + valueTypes + "': " + typeConversionFallbackMatches);
return (typeConversionFallbackMatches.size() == 1 ? typeConversionFallbackMatches.get(0) : null);
}
private boolean match(
List<ResolvableType> parameterTypes, List<ResolvableType> valueTypes, FallbackMode fallbackMode) {
if (parameterTypes.size() != valueTypes.size()) {
return false;
}
for (int i = 0; i < parameterTypes.size(); i++) {
if (!isMatch(parameterTypes.get(i), valueTypes.get(i), fallbackMode)) {
return false;
}
}
return true;
}
private boolean isMatch(ResolvableType parameterType, ResolvableType valueType, FallbackMode fallbackMode) {
if (isAssignable(valueType).test(parameterType)) {
return true;
}
return switch (fallbackMode) {
case ASSIGNABLE_ELEMENT -> isAssignable(valueType).test(extractElementType(parameterType));
case TYPE_CONVERSION -> typeConversionFallback(valueType).test(parameterType);
default -> false;
};
}
private Predicate<ResolvableType> isAssignable(ResolvableType valueType) {
return parameterType -> parameterType.isAssignableFrom(valueType);
}
private ResolvableType extractElementType(ResolvableType parameterType) {
if (parameterType.isArray()) {
return parameterType.getComponentType();
}
if (Collection.class.isAssignableFrom(parameterType.toClass())) {
return parameterType.as(Collection.class).getGeneric(0);
}
return ResolvableType.NONE;
}
private Predicate<ResolvableType> typeConversionFallback(ResolvableType valueType) {
return parameterType -> {
if (valueOrCollection(valueType, this::isStringForClassFallback).test(parameterType)) {
return true;
}
return valueOrCollection(valueType, this::isSimpleValueType).test(parameterType);
};
}
private Predicate<ResolvableType> valueOrCollection(ResolvableType valueType,
Function<ResolvableType, Predicate<ResolvableType>> predicateProvider) {
return parameterType -> {
if (predicateProvider.apply(valueType).test(parameterType)) {
return true;
}
if (predicateProvider.apply(extractElementType(valueType)).test(extractElementType(parameterType))) {
return true;
}
return (predicateProvider.apply(valueType).test(extractElementType(parameterType)));
};
}
/**
* Return a {@link Predicate} for a parameter type that checks if its target
* value is a {@link Class} and the value type is a {@link String}. This is
* a regular use cases where a {@link Class} is defined in the bean
* definition as an FQN.
* @param valueType the type of the value
* @return a predicate to indicate a fallback match for a String to Class
* parameter
*/
private Predicate<ResolvableType> isStringForClassFallback(ResolvableType valueType) {
return parameterType -> (valueType.isAssignableFrom(String.class) &&
parameterType.isAssignableFrom(Class.class));
}
private Predicate<ResolvableType> isSimpleValueType(ResolvableType valueType) {
return parameterType -> (BeanUtils.isSimpleValueType(parameterType.toClass()) &&
BeanUtils.isSimpleValueType(valueType.toClass()));
}
@Nullable
private Class<?> getFactoryBeanClass(String beanName, RootBeanDefinition mbd) {
Class<?> beanClass = this.beanFactory.resolveBeanClass(mbd, beanName);
return (beanClass != null && FactoryBean.class.isAssignableFrom(beanClass) ? beanClass : null);
}
private ResolvableType getBeanType(String beanName, RootBeanDefinition mbd) {
ResolvableType resolvableType = mbd.getResolvableType();
if (resolvableType != ResolvableType.NONE) {
return resolvableType;
}
return ResolvableType.forClass(this.beanFactory.resolveBeanClass(mbd, beanName));
}
static InjectionPoint setCurrentInjectionPoint(@Nullable InjectionPoint injectionPoint) {
InjectionPoint old = currentInjectionPoint.get();
if (injectionPoint != null) {
@ -980,4 +1304,14 @@ class ConstructorResolver {
}
}
private enum FallbackMode {
NONE,
ASSIGNABLE_ELEMENT,
TYPE_CONVERSION
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.beans.factory.support;
import java.lang.reflect.Executable;
import java.util.function.BiFunction;
import java.util.function.Supplier;
@ -71,9 +72,7 @@ public final class RegisteredBean {
* @param beanName the bean name
* @return a new {@link RegisteredBean} instance
*/
public static RegisteredBean of(ConfigurableListableBeanFactory beanFactory,
String beanName) {
public static RegisteredBean of(ConfigurableListableBeanFactory beanFactory, String beanName) {
Assert.notNull(beanFactory, "'beanFactory' must not be null");
Assert.hasLength(beanName, "'beanName' must not be empty");
return new RegisteredBean(beanFactory, () -> beanName, false,
@ -87,12 +86,9 @@ public final class RegisteredBean {
* @param innerBean a {@link BeanDefinitionHolder} for the inner bean
* @return a new {@link RegisteredBean} instance
*/
public static RegisteredBean ofInnerBean(RegisteredBean parent,
BeanDefinitionHolder innerBean) {
public static RegisteredBean ofInnerBean(RegisteredBean parent, BeanDefinitionHolder innerBean) {
Assert.notNull(innerBean, "'innerBean' must not be null");
return ofInnerBean(parent, innerBean.getBeanName(),
innerBean.getBeanDefinition());
return ofInnerBean(parent, innerBean.getBeanName(), innerBean.getBeanDefinition());
}
/**
@ -101,9 +97,7 @@ public final class RegisteredBean {
* @param innerBeanDefinition the inner-bean definition
* @return a new {@link RegisteredBean} instance
*/
public static RegisteredBean ofInnerBean(RegisteredBean parent,
BeanDefinition innerBeanDefinition) {
public static RegisteredBean ofInnerBean(RegisteredBean parent, BeanDefinition innerBeanDefinition) {
return ofInnerBean(parent, null, innerBeanDefinition);
}
@ -120,10 +114,9 @@ public final class RegisteredBean {
Assert.notNull(parent, "'parent' must not be null");
Assert.notNull(innerBeanDefinition, "'innerBeanDefinition' must not be null");
InnerBeanResolver resolver = new InnerBeanResolver(parent, innerBeanName,
innerBeanDefinition);
Supplier<String> beanName = StringUtils.hasLength(innerBeanName)
? () -> innerBeanName : resolver::resolveBeanName;
InnerBeanResolver resolver = new InnerBeanResolver(parent, innerBeanName, innerBeanDefinition);
Supplier<String> beanName = (StringUtils.hasLength(innerBeanName) ?
() -> innerBeanName : resolver::resolveBeanName);
return new RegisteredBean(parent.getBeanFactory(), beanName,
innerBeanName == null, resolver::resolveMergedBeanDefinition, parent);
}
@ -195,6 +188,16 @@ public final class RegisteredBean {
return this.parent;
}
/**
* Resolve the constructor or factory method to use for this bean.
* @return the {@link java.lang.reflect.Constructor} or {@link java.lang.reflect.Method}
*/
public Executable resolveConstructorOrFactoryMethod() {
return new ConstructorResolver((AbstractAutowireCapableBeanFactory) getBeanFactory())
.resolveConstructorOrFactoryMethod(getBeanName(), getMergedBeanDefinition());
}
@Override
public String toString() {
return new ToStringCreator(this).append("beanName", getBeanName())

View File

@ -418,6 +418,9 @@ public class RootBeanDefinition extends AbstractBeanDefinition {
*/
public void setResolvedFactoryMethod(@Nullable Method method) {
this.factoryMethodToIntrospect = method;
if (method != null) {
setUniqueFactoryMethodName(method.getName());
}
}
/**

View File

@ -294,9 +294,9 @@ class InstanceSupplierCodeGeneratorTests {
return (T) beanFactory.getBean("testBean");
}
private void compile(DefaultListableBeanFactory beanFactory,
BeanDefinition beanDefinition,
private void compile(DefaultListableBeanFactory beanFactory, BeanDefinition beanDefinition,
BiConsumer<InstanceSupplier<?>, Compiled> result) {
DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(beanFactory);
freshBeanFactory.registerBeanDefinition("testBean", beanDefinition);
RegisteredBean registeredBean = RegisteredBean.of(freshBeanFactory, "testBean");
@ -305,7 +305,7 @@ class InstanceSupplierCodeGeneratorTests {
InstanceSupplierCodeGenerator generator = new InstanceSupplierCodeGenerator(
this.generationContext, generateClass.getName(),
generateClass.getMethods(), false);
Executable constructorOrFactoryMethod = ConstructorOrFactoryMethodResolver.resolve(registeredBean);
Executable constructorOrFactoryMethod = registeredBean.resolveConstructorOrFactoryMethod();
assertThat(constructorOrFactoryMethod).isNotNull();
CodeBlock generatedCode = generator.generateCode(registeredBean, constructorOrFactoryMethod);
typeBuilder.set(type -> {

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package org.springframework.beans.factory.aot;
package org.springframework.beans.factory.support;
import java.lang.annotation.Annotation;
import java.lang.reflect.Executable;
@ -26,9 +26,6 @@ import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolder;
import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolderFactoryBean;
import org.springframework.beans.testfixture.beans.factory.generator.factory.SampleFactory;
@ -41,12 +38,12 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
/**
* Tests for {@link ConstructorOrFactoryMethodResolver}.
* Tests for AOT constructor and factory method resolution.
*
* @author Stephane Nicoll
* @author Phillip Webb
*/
class ConstructorOrFactoryMethodResolverTests {
class ConstructorAndFactoryMethodResolutionTests {
@Test
void detectBeanInstanceExecutableWithBeanClassAndFactoryMethodName() {
@ -125,21 +122,6 @@ class ConstructorOrFactoryMethodResolverTests {
.getDeclaredConstructor(Number.class, String.class));
}
@Test
void genericBeanDefinitionWithConstructorArgsForMultipleConstructors()
throws Exception {
DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
beanFactory.registerSingleton("testNumber", 1L);
beanFactory.registerSingleton("testBean", "test");
BeanDefinition beanDefinition = BeanDefinitionBuilder
.genericBeanDefinition(SampleBeanWithConstructors.class)
.addConstructorArgReference("testNumber")
.addConstructorArgReference("testBean").getBeanDefinition();
Executable executable = resolve(beanFactory, beanDefinition);
assertThat(executable).isNotNull().isEqualTo(SampleBeanWithConstructors.class
.getDeclaredConstructor(Number.class, String.class));
}
@Test
void beanDefinitionWithMultiArgConstructorAndMatchingValue() throws NoSuchMethodException {
BeanDefinition beanDefinition = BeanDefinitionBuilder
@ -341,7 +323,8 @@ class ConstructorOrFactoryMethodResolverTests {
private Executable resolve(DefaultListableBeanFactory beanFactory, BeanDefinition beanDefinition) {
return new ConstructorOrFactoryMethodResolver(beanFactory).resolve(beanDefinition);
return new ConstructorResolver(beanFactory).resolveConstructorOrFactoryMethod(
"testBean", (RootBeanDefinition) beanDefinition);
}