From 06ef82e9a5aa7f882775b259667461223cfaa8c8 Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Thu, 26 Jun 2025 12:51:55 +0200 Subject: [PATCH] Consistent type-based bean lookup for internal resolution paths Includes additional tests for List/ObjectProvider dependencies. See gh-35101 --- .../support/DefaultListableBeanFactory.java | 33 ++++--- .../support/StaticListableBeanFactory.java | 9 +- .../beans/factory/BeanFactoryUtilsTests.java | 87 +++++++++++++++++-- 3 files changed, 107 insertions(+), 22 deletions(-) diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/DefaultListableBeanFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/DefaultListableBeanFactory.java index 8248d756fd..f59766e14a 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/DefaultListableBeanFactory.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/DefaultListableBeanFactory.java @@ -496,7 +496,7 @@ public class DefaultListableBeanFactory extends AbstractAutowireCapableBeanFacto @Override public Stream stream() { return Arrays.stream(beanNamesForStream(requiredType, true, allowEagerInit)) - .map(name -> (T) getBean(name)) + .map(name -> (T) resolveBean(name, requiredType)) .filter(bean -> !(bean instanceof NullBean)); } @SuppressWarnings("unchecked") @@ -508,7 +508,7 @@ public class DefaultListableBeanFactory extends AbstractAutowireCapableBeanFacto } Map matchingBeans = CollectionUtils.newLinkedHashMap(beanNames.length); for (String beanName : beanNames) { - Object beanInstance = getBean(beanName); + Object beanInstance = resolveBean(beanName, requiredType); if (!(beanInstance instanceof NullBean)) { matchingBeans.put(beanName, (T) beanInstance); } @@ -521,7 +521,7 @@ public class DefaultListableBeanFactory extends AbstractAutowireCapableBeanFacto public Stream stream(Predicate> customFilter, boolean includeNonSingletons) { return Arrays.stream(beanNamesForStream(requiredType, includeNonSingletons, allowEagerInit)) .filter(name -> customFilter.test(getType(name))) - .map(name -> (T) getBean(name)) + .map(name -> (T) resolveBean(name, requiredType)) .filter(bean -> !(bean instanceof NullBean)); } @SuppressWarnings("unchecked") @@ -534,7 +534,7 @@ public class DefaultListableBeanFactory extends AbstractAutowireCapableBeanFacto Map matchingBeans = CollectionUtils.newLinkedHashMap(beanNames.length); for (String beanName : beanNames) { if (customFilter.test(getType(beanName))) { - Object beanInstance = getBean(beanName); + Object beanInstance = resolveBean(beanName, requiredType); if (!(beanInstance instanceof NullBean)) { matchingBeans.put(beanName, (T) beanInstance); } @@ -1207,6 +1207,17 @@ public class DefaultListableBeanFactory extends AbstractAutowireCapableBeanFacto } } + private Object resolveBean(String beanName, ResolvableType requiredType) { + try { + // Need to provide required type for SmartFactoryBean + return getBean(beanName, requiredType.toClass()); + } + catch (BeanNotOfRequiredTypeException ex) { + // Probably a null bean... + return getBean(beanName); + } + } + private static String getThreadNamePrefix() { String name = Thread.currentThread().getName(); int numberSeparator = name.lastIndexOf('-'); @@ -1542,7 +1553,7 @@ public class DefaultListableBeanFactory extends AbstractAutowireCapableBeanFacto Map candidates = CollectionUtils.newLinkedHashMap(candidateNames.length); for (String beanName : candidateNames) { if (containsSingleton(beanName) && args == null) { - Object beanInstance = getBean(beanName); + Object beanInstance = resolveBean(beanName, requiredType); candidates.put(beanName, (beanInstance instanceof NullBean ? null : beanInstance)); } else { @@ -1659,7 +1670,7 @@ public class DefaultListableBeanFactory extends AbstractAutowireCapableBeanFacto if (autowiredBeanNames != null) { autowiredBeanNames.add(dependencyName); } - Object dependencyBean = getBean(dependencyName); + Object dependencyBean = resolveBean(dependencyName, descriptor.getResolvableType()); return resolveInstance(dependencyBean, descriptor, type, dependencyName); } } @@ -2582,16 +2593,18 @@ public class DefaultListableBeanFactory extends AbstractAutowireCapableBeanFacto @Override public Stream stream(Predicate> customFilter, boolean includeNonSingletons) { - return Arrays.stream(beanNamesForStream(this.descriptor.getResolvableType(), includeNonSingletons, true)) + ResolvableType type = this.descriptor.getResolvableType(); + return Arrays.stream(beanNamesForStream(type, includeNonSingletons, true)) .filter(name -> AutowireUtils.isAutowireCandidate(DefaultListableBeanFactory.this, name)) .filter(name -> customFilter.test(getType(name))) - .map(name -> getBean(name)) + .map(name -> resolveBean(name, type)) .filter(bean -> !(bean instanceof NullBean)); } @Override public Stream orderedStream(Predicate> customFilter, boolean includeNonSingletons) { - String[] beanNames = beanNamesForStream(this.descriptor.getResolvableType(), includeNonSingletons, true); + ResolvableType type = this.descriptor.getResolvableType(); + String[] beanNames = beanNamesForStream(type, includeNonSingletons, true); if (beanNames.length == 0) { return Stream.empty(); } @@ -2599,7 +2612,7 @@ public class DefaultListableBeanFactory extends AbstractAutowireCapableBeanFacto for (String beanName : beanNames) { if (AutowireUtils.isAutowireCandidate(DefaultListableBeanFactory.this, beanName) && customFilter.test(getType(beanName))) { - Object beanInstance = getBean(beanName); + Object beanInstance = resolveBean(beanName, type); if (!(beanInstance instanceof NullBean)) { matchingBeans.put(beanName, beanInstance); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java index a9a9225c21..cfd23ce4e7 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java @@ -307,7 +307,7 @@ public class StaticListableBeanFactory implements ListableBeanFactory { public T getObject() throws BeansException { String[] beanNames = getBeanNamesForType(requiredType); if (beanNames.length == 1) { - return (T) getBean(beanNames[0], requiredType); + return (T) getBean(beanNames[0], requiredType.toClass()); } else if (beanNames.length > 1) { throw new NoUniqueBeanDefinitionException(requiredType, beanNames); @@ -333,7 +333,7 @@ public class StaticListableBeanFactory implements ListableBeanFactory { public @Nullable T getIfAvailable() throws BeansException { String[] beanNames = getBeanNamesForType(requiredType); if (beanNames.length == 1) { - return (T) getBean(beanNames[0]); + return (T) getBean(beanNames[0], requiredType.toClass()); } else if (beanNames.length > 1) { throw new NoUniqueBeanDefinitionException(requiredType, beanNames); @@ -346,7 +346,7 @@ public class StaticListableBeanFactory implements ListableBeanFactory { public @Nullable T getIfUnique() throws BeansException { String[] beanNames = getBeanNamesForType(requiredType); if (beanNames.length == 1) { - return (T) getBean(beanNames[0]); + return (T) getBean(beanNames[0], requiredType.toClass()); } else { return null; @@ -354,7 +354,8 @@ public class StaticListableBeanFactory implements ListableBeanFactory { } @Override public Stream stream() { - return Arrays.stream(getBeanNamesForType(requiredType)).map(name -> (T) getBean(name)); + return Arrays.stream(getBeanNamesForType(requiredType)) + .map(name -> (T) getBean(name, requiredType.toClass())); } }; } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java index bac0ac7f3c..9b7910ce7d 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java @@ -464,8 +464,40 @@ class BeanFactoryUtilsTests { lbf.registerSingleton("fb2", fb2); lbf.registerSingleton("sfb1", sfb1); lbf.registerSingleton("sfb2", sfb2); + lbf.registerBeanDefinition("recipient", + new RootBeanDefinition(Recipient.class, RootBeanDefinition.AUTOWIRE_CONSTRUCTOR, false)); - testSupportsMultipleTypesWithStaticFactory(lbf); + Recipient recipient = lbf.getBean("recipient", Recipient.class); + assertThat(recipient.sfb1).isSameAs(lbf.getBean("sfb1", TestBean.class)); + assertThat(recipient.sfb2).isSameAs(lbf.getBean("sfb2", TestBean.class)); + + List testBeanList = recipient.testBeanList; + assertThat(testBeanList).hasSize(5); + assertThat(testBeanList.get(0)).isSameAs(bean); + assertThat(testBeanList.get(1)).isSameAs(fb1.getObject()); + assertThat(testBeanList.get(2)).isInstanceOf(TestBean.class); + assertThat(testBeanList.get(3)).isSameAs(lbf.getBean("sfb1", TestBean.class)); + assertThat(testBeanList.get(4)).isSameAs(lbf.getBean("sfb2", TestBean.class)); + + List stringList = recipient.stringList; + assertThat(stringList).hasSize(2); + assertThat(stringList.get(0)).isSameAs(lbf.getBean("sfb1", String.class)); + assertThat(stringList.get(1)).isSameAs(lbf.getBean("sfb2", String.class)); + + testBeanList = recipient.testBeanProvider.stream().toList(); + assertThat(testBeanList).hasSize(5); + assertThat(testBeanList.get(0)).isSameAs(bean); + assertThat(testBeanList.get(1)).isSameAs(fb1.getObject()); + assertThat(testBeanList.get(2)).isInstanceOf(TestBean.class); + assertThat(testBeanList.get(3)).isSameAs(lbf.getBean("sfb1", TestBean.class)); + assertThat(testBeanList.get(4)).isSameAs(lbf.getBean("sfb2", TestBean.class)); + + stringList = recipient.stringProvider.stream().toList(); + assertThat(stringList).hasSize(2); + assertThat(stringList.get(0)).isSameAs(lbf.getBean("sfb1", String.class)); + assertThat(stringList.get(1)).isSameAs(lbf.getBean("sfb2", String.class)); + + testSupportsMultipleTypes(lbf); } @Test @@ -483,22 +515,35 @@ class BeanFactoryUtilsTests { lbf.addBean("sfb1", sfb1); lbf.addBean("sfb2", sfb2); - testSupportsMultipleTypesWithStaticFactory(lbf); + testSupportsMultipleTypes(lbf); } - void testSupportsMultipleTypesWithStaticFactory(ListableBeanFactory lbf) { + void testSupportsMultipleTypes(ListableBeanFactory lbf) { + List testBeanList = lbf.getBeanProvider(ITestBean.class).stream().toList(); + assertThat(testBeanList).hasSize(5); + assertThat(testBeanList.get(0)).isSameAs(lbf.getBean("bean", TestBean.class)); + assertThat(testBeanList.get(1)).isSameAs(lbf.getBean("fb1", TestBean.class)); + assertThat(testBeanList.get(2)).isInstanceOf(TestBean.class); + assertThat(testBeanList.get(3)).isSameAs(lbf.getBean("sfb1", TestBean.class)); + assertThat(testBeanList.get(4)).isSameAs(lbf.getBean("sfb2", TestBean.class)); + + List stringList = lbf.getBeanProvider(CharSequence.class).stream().toList(); + assertThat(stringList).hasSize(2); + assertThat(stringList.get(0)).isSameAs(lbf.getBean("sfb1", String.class)); + assertThat(stringList.get(1)).isSameAs(lbf.getBean("sfb2", String.class)); + Map beans = BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, ITestBean.class); assertThat(beans).hasSize(5); assertThat(beans.get("bean")).isSameAs(lbf.getBean("bean")); - assertThat(beans.get("fb1")).isSameAs(lbf.getBean("&fb1", DummyFactory.class).getObject()); + assertThat(beans.get("fb1")).isSameAs(lbf.getBean("fb1",TestBean.class)); assertThat(beans.get("fb2")).isInstanceOf(TestBean.class); - assertThat(beans.get("sfb1")).isInstanceOf(TestBean.class); - assertThat(beans.get("sfb2")).isInstanceOf(TestBean.class); + assertThat(beans.get("sfb1")).isSameAs(lbf.getBean("sfb1", TestBean.class)); + assertThat(beans.get("sfb2")).isSameAs(lbf.getBean("sfb2", TestBean.class)); beans = BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, CharSequence.class); assertThat(beans).hasSize(2); - assertThat(beans.get("sfb1")).isInstanceOf(String.class); - assertThat(beans.get("sfb2")).isInstanceOf(String.class); + assertThat(beans.get("sfb1")).isSameAs(lbf.getBean("sfb1", String.class)); + assertThat(beans.get("sfb2")).isSameAs(lbf.getBean("sfb1", String.class)); assertThat(lbf.getBean("sfb1", ITestBean.class)).isInstanceOf(TestBean.class); assertThat(lbf.getBean("sfb2", ITestBean.class)).isInstanceOf(TestBean.class); @@ -604,4 +649,30 @@ class BeanFactoryUtilsTests { } } + + static class Recipient { + + public Recipient(ITestBean sfb1, ITestBean sfb2, List testBeanList, List stringList, + ObjectProvider testBeanProvider, ObjectProvider stringProvider) { + this.sfb1 = sfb1; + this.sfb2 = sfb2; + this.testBeanList = testBeanList; + this.stringList = stringList; + this.testBeanProvider = testBeanProvider; + this.stringProvider = stringProvider; + } + + ITestBean sfb1; + + ITestBean sfb2; + + List testBeanList; + + List stringList; + + ObjectProvider testBeanProvider; + + ObjectProvider stringProvider; + } + }