diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/BeanDefinitionReaderUtils.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/BeanDefinitionReaderUtils.java index 5be980df09..a5f7cf97fd 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/BeanDefinitionReaderUtils.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/BeanDefinitionReaderUtils.java @@ -124,13 +124,29 @@ public abstract class BeanDefinitionReaderUtils { id = generatedBeanName + GENERATED_BEAN_NAME_SEPARATOR + ObjectUtils.getIdentityHexString(definition); } else { - // Top-level bean: use plain class name. - // Increase counter until the id is unique. - int counter = -1; - while (counter == -1 || registry.containsBeanDefinition(id)) { - counter++; - id = generatedBeanName + GENERATED_BEAN_NAME_SEPARATOR + counter; - } + // Top-level bean: use plain class name with unique suffix if necessary. + return uniqueBeanName(generatedBeanName, registry); + } + return id; + } + + /** + * Turn the given bean name into a unique bean name for the given bean factory, + * appending a unique counter as suffix if necessary. + * @param beanName the original bean name + * @param registry the bean factory that the definition is going to be + * registered with (to check for existing bean names) + * @return the unique bean name to use + * @since 5.1 + */ + public static String uniqueBeanName(String beanName, BeanDefinitionRegistry registry) { + String id = beanName; + int counter = -1; + + // Increase counter until the id is unique. + while (counter == -1 || registry.containsBeanDefinition(id)) { + counter++; + id = beanName + GENERATED_BEAN_NAME_SEPARATOR + counter; } return id; } diff --git a/spring-context/src/main/kotlin/org/springframework/context/support/BeanDefinitionDsl.kt b/spring-context/src/main/kotlin/org/springframework/context/support/BeanDefinitionDsl.kt index fcd3bec60e..7dbbb928eb 100644 --- a/spring-context/src/main/kotlin/org/springframework/context/support/BeanDefinitionDsl.kt +++ b/spring-context/src/main/kotlin/org/springframework/context/support/BeanDefinitionDsl.kt @@ -18,6 +18,7 @@ package org.springframework.context.support import org.springframework.beans.factory.config.BeanDefinitionCustomizer import org.springframework.beans.factory.support.AbstractBeanDefinition +import org.springframework.beans.factory.support.BeanDefinitionReaderUtils import org.springframework.context.ApplicationContextInitializer import org.springframework.core.env.ConfigurableEnvironment import java.util.function.Supplier @@ -167,11 +168,8 @@ open class BeanDefinitionDsl(private val init: BeanDefinitionDsl.() -> Unit, } } - when (name) { - null -> context.registerBean(T::class.java, customizer) - else -> context.registerBean(name, T::class.java, customizer) - } - + val beanName = name ?: BeanDefinitionReaderUtils.uniqueBeanName(T::class.java.name, context); + context.registerBean(beanName, T::class.java, customizer) } /** @@ -207,13 +205,8 @@ open class BeanDefinitionDsl(private val init: BeanDefinitionDsl.() -> Unit, } - when (name) { - null -> context.registerBean(T::class.java, - Supplier { function.invoke() }, customizer) - else -> context.registerBean(name, T::class.java, - Supplier { function.invoke() }, customizer) - } - + val beanName = name ?: BeanDefinitionReaderUtils.uniqueBeanName(T::class.java.name, context); + context.registerBean(beanName, T::class.java, Supplier { function.invoke() }, customizer) } /** diff --git a/spring-context/src/test/kotlin/org/springframework/context/support/BeanDefinitionDslTests.kt b/spring-context/src/test/kotlin/org/springframework/context/support/BeanDefinitionDslTests.kt index 2ed638b7e5..8863788034 100644 --- a/spring-context/src/test/kotlin/org/springframework/context/support/BeanDefinitionDslTests.kt +++ b/spring-context/src/test/kotlin/org/springframework/context/support/BeanDefinitionDslTests.kt @@ -34,7 +34,6 @@ class BeanDefinitionDslTests { val beans = beans { bean() bean("bar", scope = Scope.PROTOTYPE) - bean { Baz(ref()) } bean { Baz(ref("bar")) } } @@ -59,7 +58,6 @@ class BeanDefinitionDslTests { } } profile("baz") { - bean { Baz(ref()) } bean { Baz(ref("bar")) } } } @@ -89,7 +87,6 @@ class BeanDefinitionDslTests { bean { FooFoo(env["name"]!!) } } environment( { activeProfiles.contains("baz") } ) { - bean { Baz(ref()) } bean { Baz(ref("bar")) } } } @@ -133,8 +130,8 @@ class BeanDefinitionDslTests { @Test // SPR-16269 fun `Provide access to the context for allowing calling advanced features like getBeansOfType`() { val beans = beans { - bean("foo1") - bean("foo2") + bean() + bean() bean { BarBar(context.getBeansOfType().values) } }