diff --git a/spring-context/src/main/kotlin/org/springframework/context/support/BeanDefinitionDsl.kt b/spring-context/src/main/kotlin/org/springframework/context/support/BeanDefinitionDsl.kt index 8bf012e734..4ddbc9ab05 100644 --- a/spring-context/src/main/kotlin/org/springframework/context/support/BeanDefinitionDsl.kt +++ b/spring-context/src/main/kotlin/org/springframework/context/support/BeanDefinitionDsl.kt @@ -17,7 +17,6 @@ package org.springframework.context.support import org.springframework.beans.factory.config.BeanDefinitionCustomizer -import org.springframework.context.ApplicationContext import org.springframework.core.env.ConfigurableEnvironment import java.util.function.Supplier @@ -38,12 +37,20 @@ open class BeanDefinitionDsl(val condition: (ConfigurableEnvironment) -> Boolean PROTOTYPE } - class BeanDefinitionContext(val context: ApplicationContext) { + class BeanDefinitionContext(val context: GenericApplicationContext) { + inline fun ref(name: String? = null) : T = when (name) { null -> context.getBean(T::class.java) else -> context.getBean(name, T::class.java) } + + /** + * Get the [ConfigurableEnvironment] associated to the underlying [GenericApplicationContext]. + */ + val env : ConfigurableEnvironment + get() = context.environment + } /** diff --git a/spring-context/src/test/kotlin/org/springframework/context/support/BeanDefinitionDslTests.kt b/spring-context/src/test/kotlin/org/springframework/context/support/BeanDefinitionDslTests.kt index 690e3acc06..b4b0260bee 100644 --- a/spring-context/src/test/kotlin/org/springframework/context/support/BeanDefinitionDslTests.kt +++ b/spring-context/src/test/kotlin/org/springframework/context/support/BeanDefinitionDslTests.kt @@ -21,6 +21,7 @@ import org.junit.Test import org.springframework.beans.factory.NoSuchBeanDefinitionException import org.springframework.beans.factory.getBean import org.springframework.context.support.BeanDefinitionDsl.* +import org.springframework.core.env.SimpleCommandLinePropertySource class BeanDefinitionDslTests { @@ -75,13 +76,16 @@ class BeanDefinitionDslTests { val beans = beans { bean() bean("bar") + bean { FooFoo(it.env.getProperty("name")) } environment({it.activeProfiles.contains("baz")}) { bean { Baz(it.ref()) } bean { Baz(it.ref("bar")) } } } - val context = GenericApplicationContext() + val context = GenericApplicationContext().apply { + environment.propertySources.addFirst(SimpleCommandLinePropertySource("--name=foofoo")) + } beans.invoke(context) context.refresh() @@ -92,6 +96,8 @@ class BeanDefinitionDslTests { fail("Expect NoSuchBeanDefinitionException to be thrown") } catch(ex: NoSuchBeanDefinitionException) { null } + val foofoo = context.getBean() + assertEquals("foofoo", foofoo.name) } } @@ -99,3 +105,4 @@ class BeanDefinitionDslTests { class Foo class Bar class Baz(val bar: Bar) +class FooFoo(val name: String)