diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java index dbd44b1b0ae..fce5383e7c1 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -43,6 +44,7 @@ import org.springframework.beans.PropertyValue; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues; import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.AutowireCandidateQualifier; @@ -168,16 +170,38 @@ class BeanDefinitionPropertiesCodeGenerator { } private void addConstructorArgumentValues(CodeBlock.Builder code, BeanDefinition beanDefinition) { - Map argumentValues = - beanDefinition.getConstructorArgumentValues().getIndexedArgumentValues(); - if (!argumentValues.isEmpty()) { - argumentValues.forEach((index, valueHolder) -> { + ConstructorArgumentValues constructorValues = beanDefinition.getConstructorArgumentValues(); + Map indexedValues = constructorValues.getIndexedArgumentValues(); + if (!indexedValues.isEmpty()) { + indexedValues.forEach((index, valueHolder) -> { CodeBlock valueCode = generateValue(valueHolder.getName(), valueHolder.getValue()); code.addStatement( "$L.getConstructorArgumentValues().addIndexedArgumentValue($L, $L)", BEAN_DEFINITION_VARIABLE, index, valueCode); }); } + List genericValues = constructorValues.getGenericArgumentValues(); + if (!genericValues.isEmpty()) { + genericValues.forEach(valueHolder -> { + String valueName = valueHolder.getName(); + CodeBlock valueCode = generateValue(valueName, valueHolder.getValue()); + if (valueName != null) { + CodeBlock valueTypeCode = this.valueCodeGenerator.generateCode(valueHolder.getType()); + code.addStatement( + "$L.getConstructorArgumentValues().addGenericArgumentValue(new $T($L, $L, $S))", + BEAN_DEFINITION_VARIABLE, ValueHolder.class, valueCode, valueTypeCode, valueName); + } + else if (valueHolder.getType() != null) { + code.addStatement("$L.getConstructorArgumentValues().addGenericArgumentValue($L, $S)", + BEAN_DEFINITION_VARIABLE, valueCode, valueHolder.getType()); + + } + else { + code.addStatement("$L.getConstructorArgumentValues().addGenericArgumentValue($L)", + BEAN_DEFINITION_VARIABLE, valueCode); + } + }); + } } private void addPropertyValues(CodeBlock.Builder code, RootBeanDefinition beanDefinition) { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java index 9d959f2b313..2be49bcf516 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java @@ -20,8 +20,11 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Executable; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.lang.reflect.Parameter; import java.util.Arrays; +import java.util.HashSet; import java.util.LinkedHashSet; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -248,7 +251,7 @@ public final class BeanInstanceSupplier extends AutowiredElementResolver impl Assert.isTrue(this.shortcuts == null || this.shortcuts.length == resolved.length, () -> "'shortcuts' must contain " + resolved.length + " elements"); - ConstructorArgumentValues argumentValues = resolveArgumentValues(registeredBean); + ValueHolder[] argumentValues = resolveArgumentValues(registeredBean, executable); Set autowiredBeanNames = new LinkedHashSet<>(resolved.length * 2); for (int i = startIndex; i < parameterCount; i++) { MethodParameter parameter = getMethodParameter(executable, i); @@ -257,8 +260,9 @@ public final class BeanInstanceSupplier extends AutowiredElementResolver impl if (shortcut != null) { descriptor = new ShortcutDependencyDescriptor(descriptor, shortcut); } - ValueHolder argumentValue = argumentValues.getIndexedArgumentValue(i, null); - resolved[i - startIndex] = resolveArgument(registeredBean, descriptor, argumentValue, autowiredBeanNames); + ValueHolder argumentValue = argumentValues[i]; + resolved[i - startIndex] = resolveAutowiredArgument( + registeredBean, descriptor, argumentValue, autowiredBeanNames); } registerDependentBeans(registeredBean.getBeanFactory(), registeredBean.getBeanName(), autowiredBeanNames); @@ -275,22 +279,44 @@ public final class BeanInstanceSupplier extends AutowiredElementResolver impl throw new IllegalStateException("Unsupported executable: " + executable.getClass().getName()); } - private ConstructorArgumentValues resolveArgumentValues(RegisteredBean registeredBean) { - ConstructorArgumentValues resolved = new ConstructorArgumentValues(); + private ValueHolder[] resolveArgumentValues(RegisteredBean registeredBean, Executable executable) { + Parameter[] parameters = executable.getParameters(); + ValueHolder[] resolved = new ValueHolder[parameters.length]; RootBeanDefinition beanDefinition = registeredBean.getMergedBeanDefinition(); if (beanDefinition.hasConstructorArgumentValues() && registeredBean.getBeanFactory() instanceof AbstractAutowireCapableBeanFactory beanFactory) { BeanDefinitionValueResolver valueResolver = new BeanDefinitionValueResolver( beanFactory, registeredBean.getBeanName(), beanDefinition, beanFactory.getTypeConverter()); - ConstructorArgumentValues values = beanDefinition.getConstructorArgumentValues(); - values.getIndexedArgumentValues().forEach((index, valueHolder) -> { - ValueHolder resolvedValue = resolveArgumentValue(valueResolver, valueHolder); - resolved.addIndexedArgumentValue(index, resolvedValue); - }); + ConstructorArgumentValues values = resolveConstructorArguments( + valueResolver, beanDefinition.getConstructorArgumentValues()); + Set usedValueHolders = new HashSet<>(parameters.length); + for (int i = 0; i < parameters.length; i++) { + Class parameterType = parameters[i].getType(); + String parameterName = (parameters[i].isNamePresent() ? parameters[i].getName() : null); + ValueHolder valueHolder = values.getArgumentValue( + i, parameterType, parameterName, usedValueHolders); + if (valueHolder != null) { + resolved[i] = valueHolder; + usedValueHolders.add(valueHolder); + } + } } return resolved; } + private ConstructorArgumentValues resolveConstructorArguments( + BeanDefinitionValueResolver valueResolver, ConstructorArgumentValues constructorArguments) { + + ConstructorArgumentValues resolvedConstructorArguments = new ConstructorArgumentValues(); + for (Map.Entry entry : constructorArguments.getIndexedArgumentValues().entrySet()) { + resolvedConstructorArguments.addIndexedArgumentValue(entry.getKey(), resolveArgumentValue(valueResolver, entry.getValue())); + } + for (ConstructorArgumentValues.ValueHolder valueHolder : constructorArguments.getGenericArgumentValues()) { + resolvedConstructorArguments.addGenericArgumentValue(resolveArgumentValue(valueResolver, valueHolder)); + } + return resolvedConstructorArguments; + } + private ValueHolder resolveArgumentValue(BeanDefinitionValueResolver resolver, ValueHolder valueHolder) { if (valueHolder.isConverted()) { return valueHolder; @@ -302,7 +328,7 @@ public final class BeanInstanceSupplier extends AutowiredElementResolver impl } @Nullable - private Object resolveArgument(RegisteredBean registeredBean, DependencyDescriptor descriptor, + private Object resolveAutowiredArgument(RegisteredBean registeredBean, DependencyDescriptor descriptor, @Nullable ValueHolder argumentValue, Set autowiredBeanNames) { TypeConverter typeConverter = registeredBean.getBeanFactory().getTypeConverter(); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/ConstructorResolver.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/ConstructorResolver.java index 1a32e7000b2..c4ecb81c6aa 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/ConstructorResolver.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/ConstructorResolver.java @@ -61,6 +61,7 @@ import org.springframework.beans.factory.config.ConstructorArgumentValues; import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; import org.springframework.beans.factory.config.DependencyDescriptor; import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.config.TypedStringValue; import org.springframework.core.CollectionFactory; import org.springframework.core.MethodParameter; import org.springframework.core.NamedThreadLocal; @@ -999,6 +1000,9 @@ class ConstructorResolver { for (ValueHolder valueHolder : mbd.getConstructorArgumentValues().getIndexedArgumentValues().values()) { parameterTypes.add(determineParameterValueType(mbd, valueHolder)); } + for (ValueHolder valueHolder : mbd.getConstructorArgumentValues().getGenericArgumentValues()) { + parameterTypes.add(determineParameterValueType(mbd, valueHolder)); + } return parameterTypes; } @@ -1023,6 +1027,12 @@ class ConstructorResolver { return (FactoryBean.class.isAssignableFrom(type.toClass()) ? type.as(FactoryBean.class).getGeneric(0) : type); } + if (value instanceof TypedStringValue typedValue) { + if (typedValue.hasTargetType()) { + return ResolvableType.forClass(typedValue.getTargetType()); + } + return ResolvableType.forClass(String.class); + } if (value instanceof Class clazz) { return ResolvableType.forClassWithGenerics(Class.class, clazz); } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java index 8cdf923bb69..8c736e25edf 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java @@ -40,6 +40,7 @@ import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanReference; +import org.springframework.beans.factory.config.ConstructorArgumentValues; import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; import org.springframework.beans.factory.config.RuntimeBeanNameReference; import org.springframework.beans.factory.config.RuntimeBeanReference; @@ -219,18 +220,49 @@ class BeanDefinitionPropertiesCodeGeneratorTests { } @Test - void constructorArgumentValuesWhenValues() { + void constructorArgumentValuesWhenIndexedValues() { this.beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, String.class); this.beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(1, "test"); this.beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(2, 123); compile((actual, compiled) -> { - Map values = actual.getConstructorArgumentValues().getIndexedArgumentValues(); - assertThat(values.get(0).getValue()).isEqualTo(String.class); - assertThat(values.get(1).getValue()).isEqualTo("test"); - assertThat(values.get(2).getValue()).isEqualTo(123); + ConstructorArgumentValues argumentValues = actual.getConstructorArgumentValues(); + Map values = argumentValues.getIndexedArgumentValues(); + assertThat(values.get(0)).satisfies(assertValueHolder(String.class, null, null)); + assertThat(values.get(1)).satisfies(assertValueHolder("test", null, null)); + assertThat(values.get(2)).satisfies(assertValueHolder(123, null, null)); + assertThat(values).hasSize(3); + assertThat(argumentValues.getGenericArgumentValues()).isEmpty(); }); } + @Test + void constructorArgumentValuesWhenGenericValuesWithName() { + this.beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(String.class); + this.beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(2, Long.class.getName()); + this.beanDefinition.getConstructorArgumentValues().addGenericArgumentValue( + new ValueHolder("value", null, "param1")); + this.beanDefinition.getConstructorArgumentValues().addGenericArgumentValue( + new ValueHolder("another", CharSequence.class.getName(), "param2")); + compile((actual, compiled) -> { + ConstructorArgumentValues argumentValues = actual.getConstructorArgumentValues(); + List values = argumentValues.getGenericArgumentValues(); + assertThat(values.get(0)).satisfies(assertValueHolder(String.class, null, null)); + assertThat(values.get(1)).satisfies(assertValueHolder(2, Long.class, null)); + assertThat(values.get(2)).satisfies(assertValueHolder("value", null, "param1")); + assertThat(values.get(3)).satisfies(assertValueHolder("another", CharSequence.class, "param2")); + assertThat(values).hasSize(4); + assertThat(argumentValues.getIndexedArgumentValues()).isEmpty(); + }); + } + + private Consumer assertValueHolder(Object value, @Nullable Class type, @Nullable String name) { + return valueHolder -> { + assertThat(valueHolder.getValue()).isEqualTo(value); + assertThat(valueHolder.getType()).isEqualTo((type != null ? type.getName() : null)); + assertThat(valueHolder.getName()).isEqualTo(name); + }; + } + @Test void propertyValuesWhenValues() { this.beanDefinition.setTargetType(PropertyValuesBean.class); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java index 422de9d4420..6d8da260546 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java @@ -444,7 +444,7 @@ class BeanInstanceSupplierTests { } @ParameterizedResolverTest(Sources.MIXED_ARGS) - void resolveArgumentsWithMixedArgsConstructorWithUserValue(Source source) { + void resolveArgumentsWithMixedArgsConstructorWithIndexedUserValue(Source source) { ResourceLoader resourceLoader = new DefaultResourceLoader(); Environment environment = mock(); this.beanFactory.registerResolvableDependency(ResourceLoader.class, @@ -465,7 +465,28 @@ class BeanInstanceSupplierTests { } @ParameterizedResolverTest(Sources.MIXED_ARGS) - void resolveArgumentsWithMixedArgsConstructorWithUserBeanReference(Source source) { + void resolveArgumentsWithMixedArgsConstructorWithGenericUserValue(Source source) { + ResourceLoader resourceLoader = new DefaultResourceLoader(); + Environment environment = mock(); + this.beanFactory.registerResolvableDependency(ResourceLoader.class, + resourceLoader); + this.beanFactory.registerSingleton("environment", environment); + RegisteredBean registerBean = source.registerBean(this.beanFactory, + beanDefinition -> { + beanDefinition + .setAutowireMode(AbstractBeanDefinition.AUTOWIRE_CONSTRUCTOR); + beanDefinition.getConstructorArgumentValues() + .addGenericArgumentValue("user-value"); + }); + AutowiredArguments arguments = source.getResolver().resolveArguments(registerBean); + assertThat(arguments.toArray()).hasSize(3); + assertThat(arguments.getObject(0)).isEqualTo(resourceLoader); + assertThat(arguments.getObject(1)).isEqualTo("user-value"); + assertThat(arguments.getObject(2)).isEqualTo(environment); + } + + @ParameterizedResolverTest(Sources.MIXED_ARGS) + void resolveArgumentsWithMixedArgsConstructorAndIndexedUserBeanReference(Source source) { ResourceLoader resourceLoader = new DefaultResourceLoader(); Environment environment = mock(); this.beanFactory.registerResolvableDependency(ResourceLoader.class, @@ -487,8 +508,31 @@ class BeanInstanceSupplierTests { assertThat(arguments.getObject(2)).isEqualTo(environment); } + @ParameterizedResolverTest(Sources.MIXED_ARGS) + void resolveArgumentsWithMixedArgsConstructorAndGenericUserBeanReference(Source source) { + ResourceLoader resourceLoader = new DefaultResourceLoader(); + Environment environment = mock(); + this.beanFactory.registerResolvableDependency(ResourceLoader.class, + resourceLoader); + this.beanFactory.registerSingleton("environment", environment); + this.beanFactory.registerSingleton("one", "1"); + this.beanFactory.registerSingleton("two", "2"); + RegisteredBean registerBean = source.registerBean(this.beanFactory, + beanDefinition -> { + beanDefinition + .setAutowireMode(AbstractBeanDefinition.AUTOWIRE_CONSTRUCTOR); + beanDefinition.getConstructorArgumentValues() + .addGenericArgumentValue(new RuntimeBeanReference("two")); + }); + AutowiredArguments arguments = source.getResolver().resolveArguments(registerBean); + assertThat(arguments.toArray()).hasSize(3); + assertThat(arguments.getObject(0)).isEqualTo(resourceLoader); + assertThat(arguments.getObject(1)).isEqualTo("2"); + assertThat(arguments.getObject(2)).isEqualTo(environment); + } + @Test - void resolveArgumentsWithUserValueWithTypeConversionRequired() { + void resolveIndexedArgumentsWithUserValueWithTypeConversionRequired() { Source source = new Source(CharDependency.class, BeanInstanceSupplier.forConstructor(char.class)); RegisteredBean registerBean = source.registerBean(this.beanFactory, @@ -503,8 +547,24 @@ class BeanInstanceSupplierTests { assertThat(arguments.getObject(0)).isInstanceOf(Character.class).isEqualTo('\\'); } + @Test + void resolveGenericArgumentsWithUserValueWithTypeConversionRequired() { + Source source = new Source(CharDependency.class, + BeanInstanceSupplier.forConstructor(char.class)); + RegisteredBean registerBean = source.registerBean(this.beanFactory, + beanDefinition -> { + beanDefinition + .setAutowireMode(AbstractBeanDefinition.AUTOWIRE_CONSTRUCTOR); + beanDefinition.getConstructorArgumentValues() + .addGenericArgumentValue("\\", char.class.getName()); + }); + AutowiredArguments arguments = source.getResolver().resolveArguments(registerBean); + assertThat(arguments.toArray()).hasSize(1); + assertThat(arguments.getObject(0)).isInstanceOf(Character.class).isEqualTo('\\'); + } + @ParameterizedResolverTest(Sources.SINGLE_ARG) - void resolveArgumentsWithUserValueWithBeanReference(Source source) { + void resolveIndexedArgumentsWithUserValueWithBeanReference(Source source) { this.beanFactory.registerSingleton("stringBean", "string"); RegisteredBean registerBean = source.registerBean(this.beanFactory, beanDefinition -> beanDefinition.getConstructorArgumentValues() @@ -516,7 +576,18 @@ class BeanInstanceSupplierTests { } @ParameterizedResolverTest(Sources.SINGLE_ARG) - void resolveArgumentsWithUserValueWithBeanDefinition(Source source) { + void resolveGenericArgumentsWithUserValueWithBeanReference(Source source) { + this.beanFactory.registerSingleton("stringBean", "string"); + RegisteredBean registerBean = source.registerBean(this.beanFactory, + beanDefinition -> beanDefinition.getConstructorArgumentValues() + .addGenericArgumentValue(new RuntimeBeanReference("stringBean"))); + AutowiredArguments arguments = source.getResolver().resolveArguments(registerBean); + assertThat(arguments.toArray()).hasSize(1); + assertThat(arguments.getObject(0)).isEqualTo("string"); + } + + @ParameterizedResolverTest(Sources.SINGLE_ARG) + void resolveIndexedArgumentsWithUserValueWithBeanDefinition(Source source) { AbstractBeanDefinition userValue = BeanDefinitionBuilder .rootBeanDefinition(String.class, () -> "string").getBeanDefinition(); RegisteredBean registerBean = source.registerBean(this.beanFactory, @@ -528,11 +599,23 @@ class BeanInstanceSupplierTests { } @ParameterizedResolverTest(Sources.SINGLE_ARG) - void resolveArgumentsWithUserValueThatIsAlreadyResolved(Source source) { + void resolveGenericArgumentsWithUserValueWithBeanDefinition(Source source) { + AbstractBeanDefinition userValue = BeanDefinitionBuilder + .rootBeanDefinition(String.class, () -> "string").getBeanDefinition(); + RegisteredBean registerBean = source.registerBean(this.beanFactory, + beanDefinition -> beanDefinition.getConstructorArgumentValues() + .addGenericArgumentValue(userValue)); + AutowiredArguments arguments = source.getResolver().resolveArguments(registerBean); + assertThat(arguments.toArray()).hasSize(1); + assertThat(arguments.getObject(0)).isEqualTo("string"); + } + + @ParameterizedResolverTest(Sources.SINGLE_ARG) + void resolveIndexedArgumentsWithUserValueThatIsAlreadyResolved(Source source) { RegisteredBean registerBean = source.registerBean(this.beanFactory); BeanDefinition mergedBeanDefinition = this.beanFactory .getMergedBeanDefinition("testBean"); - ValueHolder valueHolder = new ValueHolder('a'); + ValueHolder valueHolder = new ValueHolder("a"); valueHolder.setConvertedValue("this is an a"); mergedBeanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, valueHolder); @@ -541,6 +624,19 @@ class BeanInstanceSupplierTests { assertThat(arguments.getObject(0)).isEqualTo("this is an a"); } + @ParameterizedResolverTest(Sources.SINGLE_ARG) + void resolveGenericArgumentsWithUserValueThatIsAlreadyResolved(Source source) { + RegisteredBean registerBean = source.registerBean(this.beanFactory); + BeanDefinition mergedBeanDefinition = this.beanFactory + .getMergedBeanDefinition("testBean"); + ValueHolder valueHolder = new ValueHolder("a"); + valueHolder.setConvertedValue("this is an a"); + mergedBeanDefinition.getConstructorArgumentValues().addGenericArgumentValue(valueHolder); + AutowiredArguments arguments = source.getResolver().resolveArguments(registerBean); + assertThat(arguments.toArray()).hasSize(1); + assertThat(arguments.getObject(0)).isEqualTo("this is an a"); + } + @Test void resolveArgumentsWhenUsingShortcutsInjectsDirectly() { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory() { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/support/ConstructorResolverAotTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/support/ConstructorResolverAotTests.java index da6b27d3609..f40bbfecbe1 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/support/ConstructorResolverAotTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/support/ConstructorResolverAotTests.java @@ -26,6 +26,7 @@ import org.junit.jupiter.api.Test; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.TypedStringValue; import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolder; import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolderFactoryBean; import org.springframework.beans.testfixture.beans.factory.generator.factory.SampleFactory; @@ -72,7 +73,7 @@ class ConstructorResolverAotTests { } @Test - void beanDefinitionWithFactoryMethodNameAndAssignableConstructorArg() { + void beanDefinitionWithFactoryMethodNameAndAssignableIndexedConstructorArgs() { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); beanFactory.registerSingleton("testNumber", 1L); beanFactory.registerSingleton("testBean", "test"); @@ -85,6 +86,34 @@ class ConstructorResolverAotTests { .findMethod(SampleFactory.class, "create", Number.class, String.class)); } + @Test + void beanDefinitionWithFactoryMethodNameAndAssignableGenericConstructorArgs() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + AbstractBeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(SampleFactory.class).setFactoryMethod("create") + .getBeanDefinition(); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue("test"); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(1L); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo(ReflectionUtils + .findMethod(SampleFactory.class, "create", Number.class, String.class)); + } + + @Test + void beanDefinitionWithFactoryMethodNameAndAssignableTypeStringValues() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + AbstractBeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(SampleFactory.class).setFactoryMethod("create") + .getBeanDefinition(); + beanDefinition.getConstructorArgumentValues() + .addGenericArgumentValue(new TypedStringValue("test")); + beanDefinition.getConstructorArgumentValues() + .addGenericArgumentValue(new TypedStringValue("1", Integer.class)); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo(ReflectionUtils + .findMethod(SampleFactory.class, "create", Number.class, String.class)); + } + @Test void beanDefinitionWithFactoryMethodNameAndMatchingMethodNames() { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); @@ -122,7 +151,7 @@ class ConstructorResolverAotTests { } @Test - void beanDefinitionWithConstructorArgsForMultipleConstructors() throws Exception { + void beanDefinitionWithIndexedConstructorArgsForMultipleConstructors() throws Exception { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); beanFactory.registerSingleton("testNumber", 1L); beanFactory.registerSingleton("testBean", "test"); @@ -136,7 +165,22 @@ class ConstructorResolverAotTests { } @Test - void beanDefinitionWithMultiArgConstructorAndMatchingValue() throws NoSuchMethodException { + void beanDefinitionWithGenericConstructorArgsForMultipleConstructors() throws Exception { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerSingleton("testNumber", 1L); + beanFactory.registerSingleton("testBean", "test"); + AbstractBeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(SampleBeanWithConstructors.class) + .getBeanDefinition(); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue("test"); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(1L); + Executable executable = resolve(beanFactory, beanDefinition); + assertThat(executable).isNotNull().isEqualTo(SampleBeanWithConstructors.class + .getDeclaredConstructor(Number.class, String.class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingIndexedValue() throws NoSuchMethodException { BeanDefinition beanDefinition = BeanDefinitionBuilder .rootBeanDefinition(MultiConstructorSample.class) .addConstructorArgValue(42).getBeanDefinition(); @@ -146,7 +190,18 @@ class ConstructorResolverAotTests { } @Test - void beanDefinitionWithMultiArgConstructorAndMatchingArrayValue() throws NoSuchMethodException { + void beanDefinitionWithMultiArgConstructorAndMatchingGenericValue() throws NoSuchMethodException { + AbstractBeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorSample.class) + .getBeanDefinition(); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(42); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + MultiConstructorSample.class.getDeclaredConstructor(Integer.class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingArrayFromIndexedValue() throws NoSuchMethodException { BeanDefinition beanDefinition = BeanDefinitionBuilder .rootBeanDefinition(MultiConstructorArraySample.class) .addConstructorArgValue(42).getBeanDefinition(); @@ -156,7 +211,18 @@ class ConstructorResolverAotTests { } @Test - void beanDefinitionWithMultiArgConstructorAndMatchingListValue() throws NoSuchMethodException { + void beanDefinitionWithMultiArgConstructorAndMatchingArrayFromGenericValue() throws NoSuchMethodException { + AbstractBeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorArraySample.class) + .getBeanDefinition(); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(42); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNotNull().isEqualTo(MultiConstructorArraySample.class + .getDeclaredConstructor(Integer[].class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingListFromIndexedValue() throws NoSuchMethodException { BeanDefinition beanDefinition = BeanDefinitionBuilder .rootBeanDefinition(MultiConstructorListSample.class) .addConstructorArgValue(42).getBeanDefinition(); @@ -166,7 +232,18 @@ class ConstructorResolverAotTests { } @Test - void beanDefinitionWithMultiArgConstructorAndMatchingValueAsInnerBean() throws NoSuchMethodException { + void beanDefinitionWithMultiArgConstructorAndMatchingListFromGenericValue() throws NoSuchMethodException { + AbstractBeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorListSample.class) + .getBeanDefinition(); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(42); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + MultiConstructorListSample.class.getDeclaredConstructor(List.class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingIndexedValueAsInnerBean() throws NoSuchMethodException { BeanDefinition beanDefinition = BeanDefinitionBuilder .rootBeanDefinition(MultiConstructorSample.class) .addConstructorArgValue( @@ -179,7 +256,20 @@ class ConstructorResolverAotTests { } @Test - void beanDefinitionWithMultiArgConstructorAndMatchingValueAsInnerBeanFactory() throws NoSuchMethodException { + void beanDefinitionWithMultiArgConstructorAndMatchingGenericValueAsInnerBean() throws NoSuchMethodException { + AbstractBeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorSample.class) + .getBeanDefinition(); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue( + BeanDefinitionBuilder.rootBeanDefinition(Integer.class, "valueOf") + .addConstructorArgValue("42").getBeanDefinition()); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + MultiConstructorSample.class.getDeclaredConstructor(Integer.class)); + } + + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingIndexedValueAsInnerBeanFactory() throws NoSuchMethodException { BeanDefinition beanDefinition = BeanDefinitionBuilder .rootBeanDefinition(MultiConstructorSample.class) .addConstructorArgValue(BeanDefinitionBuilder @@ -190,6 +280,18 @@ class ConstructorResolverAotTests { MultiConstructorSample.class.getDeclaredConstructor(Integer.class)); } + @Test + void beanDefinitionWithMultiArgConstructorAndMatchingGenericValueAsInnerBeanFactory() throws NoSuchMethodException { + AbstractBeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(MultiConstructorSample.class) + .getBeanDefinition(); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue( + BeanDefinitionBuilder.rootBeanDefinition(IntegerFactoryBean.class).getBeanDefinition()); + Executable executable = resolve(new DefaultListableBeanFactory(), beanDefinition); + assertThat(executable).isNotNull().isEqualTo( + MultiConstructorSample.class.getDeclaredConstructor(Integer.class)); + } + @Test void beanDefinitionWithMultiArgConstructorAndNonMatchingValue() { BeanDefinition beanDefinition = BeanDefinitionBuilder diff --git a/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java b/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java index 8d4a29c6bca..7735045a980 100644 --- a/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java +++ b/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java @@ -30,6 +30,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.beans.PropertyValue; import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.BeanDefinitionHolder; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; @@ -479,26 +480,32 @@ final class PostProcessorRegistrationDelegate { BeanDefinitionValueResolver valueResolver = new BeanDefinitionValueResolver(this.beanFactory, beanName, bd); postProcessors.forEach(postProcessor -> postProcessor.postProcessMergedBeanDefinition(bd, beanType, beanName)); for (PropertyValue propertyValue : bd.getPropertyValues().getPropertyValueList()) { - Object value = propertyValue.getValue(); - if (value instanceof AbstractBeanDefinition innerBd) { - Class innerBeanType = resolveBeanType(innerBd); - resolveInnerBeanDefinition(valueResolver, innerBd, (innerBeanName, innerBeanDefinition) - -> postProcessRootBeanDefinition(postProcessors, innerBeanName, innerBeanType, innerBeanDefinition)); - } - if (value instanceof TypedStringValue typedStringValue) { - resolveTypeStringValue(typedStringValue); - } + postProcessValue(postProcessors, valueResolver, propertyValue.getValue()); } for (ValueHolder valueHolder : bd.getConstructorArgumentValues().getIndexedArgumentValues().values()) { - Object value = valueHolder.getValue(); - if (value instanceof AbstractBeanDefinition innerBd) { - Class innerBeanType = resolveBeanType(innerBd); - resolveInnerBeanDefinition(valueResolver, innerBd, (innerBeanName, innerBeanDefinition) - -> postProcessRootBeanDefinition(postProcessors, innerBeanName, innerBeanType, innerBeanDefinition)); - } - if (value instanceof TypedStringValue typedStringValue) { - resolveTypeStringValue(typedStringValue); - } + postProcessValue(postProcessors, valueResolver, valueHolder.getValue()); + } + for (ValueHolder valueHolder : bd.getConstructorArgumentValues().getGenericArgumentValues()) { + postProcessValue(postProcessors, valueResolver, valueHolder.getValue()); + } + } + + private void postProcessValue(List postProcessors, + BeanDefinitionValueResolver valueResolver, @Nullable Object value) { + if (value instanceof BeanDefinitionHolder bdh + && bdh.getBeanDefinition() instanceof AbstractBeanDefinition innerBd) { + + Class innerBeanType = resolveBeanType(innerBd); + resolveInnerBeanDefinition(valueResolver, innerBd, (innerBeanName, innerBeanDefinition) + -> postProcessRootBeanDefinition(postProcessors, innerBeanName, innerBeanType, innerBeanDefinition)); + } + else if (value instanceof AbstractBeanDefinition innerBd) { + Class innerBeanType = resolveBeanType(innerBd); + resolveInnerBeanDefinition(valueResolver, innerBd, (innerBeanName, innerBeanDefinition) + -> postProcessRootBeanDefinition(postProcessors, innerBeanName, innerBeanType, innerBeanDefinition)); + } + else if (value instanceof TypedStringValue typedStringValue) { + resolveTypeStringValue(typedStringValue); } } diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java index ea8125c319e..65a35a46f5a 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java @@ -459,8 +459,10 @@ class ApplicationContextAotGeneratorTests { assertThat(employee.getName()).isEqualTo("John Smith"); assertThat(employee.getAge()).isEqualTo(42); assertThat(employee.getCompany()).isEqualTo("Acme Widgets, Inc."); - assertThat(freshApplicationContext.getBean("pet", Pet.class) + assertThat(freshApplicationContext.getBean("petIndexed", Pet.class) .getName()).isEqualTo("Fido"); + assertThat(freshApplicationContext.getBean("petGeneric", Pet.class) + .getName()).isEqualTo("Dofi"); }); } @@ -496,6 +498,20 @@ class ApplicationContextAotGeneratorTests { }); } + @Test + void processAheadOfTimeWhenXmlHasBeanReferences() { + GenericXmlApplicationContext applicationContext = new GenericXmlApplicationContext(); + applicationContext + .load(new ClassPathResource("applicationContextAotGeneratorTests-references.xml", getClass())); + testCompiledResult(applicationContext, (initializer, compiled) -> { + GenericApplicationContext freshApplicationContext = toFreshApplicationContext(initializer); + assertThat(freshApplicationContext.getBean("petInnerBean", Pet.class) + .getName()).isEqualTo("Fido"); + assertThat(freshApplicationContext.getBean("petRefBean", Pet.class) + .getName()).isEqualTo("Dofi"); + }); + } + } private Consumer> doesNotHaveProxyFor(Class target) { diff --git a/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java b/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java index 731da7c078d..baa0bb1e161 100644 --- a/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java +++ b/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java @@ -331,7 +331,7 @@ class GenericApplicationContextTests { } @Test - void refreshForAotLoadsBeanClassNameOfConstructorArgumentInnerBeanDefinition() { + void refreshForAotLoadsBeanClassNameOfIndexedConstructorArgumentInnerBeanDefinition() { GenericApplicationContext context = new GenericApplicationContext(); RootBeanDefinition beanDefinition = new RootBeanDefinition(String.class); GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition(); @@ -347,6 +347,23 @@ class GenericApplicationContextTests { context.close(); } + @Test + void refreshForAotLoadsBeanClassNameOfGenericConstructorArgumentInnerBeanDefinition() { + GenericApplicationContext context = new GenericApplicationContext(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(String.class); + GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition(); + innerBeanDefinition.setBeanClassName("java.lang.Integer"); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(innerBeanDefinition); + context.registerBeanDefinition("test",beanDefinition); + context.refreshForAotProcessing(new RuntimeHints()); + RootBeanDefinition bd = getBeanDefinition(context, "test"); + GenericBeanDefinition value = (GenericBeanDefinition) bd.getConstructorArgumentValues() + .getGenericArgumentValues().get(0).getValue(); + assertThat(value.hasBeanClass()).isTrue(); + assertThat(value.getBeanClass()).isEqualTo(Integer.class); + context.close(); + } + @Test void refreshForAotLoadsBeanClassNameOfPropertyValueInnerBeanDefinition() { GenericApplicationContext context = new GenericApplicationContext(); @@ -377,7 +394,7 @@ class GenericApplicationContextTests { } @Test - void refreshForAotLoadsTypedStringValueClassNameInConstructorArgument() { + void refreshForAotLoadsTypedStringValueClassNameInIndexedConstructorArgument() { GenericApplicationContext context = new GenericApplicationContext(); RootBeanDefinition beanDefinition = new RootBeanDefinition("java.lang.Integer"); beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, @@ -391,6 +408,21 @@ class GenericApplicationContextTests { context.close(); } + @Test + void refreshForAotLoadsTypedStringValueClassNameInGenericConstructorArgument() { + GenericApplicationContext context = new GenericApplicationContext(); + RootBeanDefinition beanDefinition = new RootBeanDefinition("java.lang.Integer"); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue( + new TypedStringValue("42", "java.lang.Integer")); + context.registerBeanDefinition("number", beanDefinition); + context.refreshForAotProcessing(new RuntimeHints()); + assertThat(getBeanDefinition(context, "number").getConstructorArgumentValues() + .getGenericArgumentValue(TypedStringValue.class).getValue()) + .isInstanceOfSatisfying(TypedStringValue.class, typeStringValue -> + assertThat(typeStringValue.getTargetType()).isEqualTo(Integer.class)); + context.close(); + } + @Test void refreshForAotInvokesBeanFactoryPostProcessors() { GenericApplicationContext context = new GenericApplicationContext(); @@ -414,7 +446,7 @@ class GenericApplicationContextTests { } @Test - void refreshForAotInvokesMergedBeanDefinitionPostProcessorsOnConstructorArgument() { + void refreshForAotInvokesMergedBeanDefinitionPostProcessorsOnIndexedConstructorArgument() { GenericApplicationContext context = new GenericApplicationContext(); RootBeanDefinition beanDefinition = new RootBeanDefinition(BeanD.class); GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition(); @@ -430,6 +462,23 @@ class GenericApplicationContextTests { context.close(); } + @Test + void refreshForAotInvokesMergedBeanDefinitionPostProcessorsOnGenericConstructorArgument() { + GenericApplicationContext context = new GenericApplicationContext(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(BeanD.class); + GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition(); + innerBeanDefinition.setBeanClassName("java.lang.Integer"); + beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(innerBeanDefinition); + context.registerBeanDefinition("test", beanDefinition); + MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context); + context.refreshForAotProcessing(new RuntimeHints()); + ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); + verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), BeanD.class, "test"); + verify(bpp).postProcessMergedBeanDefinition(any(RootBeanDefinition.class), eq(Integer.class), captor.capture()); + assertThat(captor.getValue()).startsWith("(inner bean)"); + context.close(); + } + @Test void refreshForAotInvokesMergedBeanDefinitionPostProcessorsOnPropertyValue() { GenericApplicationContext context = new GenericApplicationContext(); diff --git a/spring-context/src/test/resources/org/springframework/context/aot/applicationContextAotGeneratorTests-references.xml b/spring-context/src/test/resources/org/springframework/context/aot/applicationContextAotGeneratorTests-references.xml new file mode 100644 index 00000000000..992ed41df19 --- /dev/null +++ b/spring-context/src/test/resources/org/springframework/context/aot/applicationContextAotGeneratorTests-references.xml @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/spring-context/src/test/resources/org/springframework/context/aot/applicationContextAotGeneratorTests-values.xml b/spring-context/src/test/resources/org/springframework/context/aot/applicationContextAotGeneratorTests-values.xml index 6adf0ee52ff..964f3c03e7f 100644 --- a/spring-context/src/test/resources/org/springframework/context/aot/applicationContextAotGeneratorTests-values.xml +++ b/spring-context/src/test/resources/org/springframework/context/aot/applicationContextAotGeneratorTests-values.xml @@ -8,8 +8,12 @@ - + + + + +