Browse Source

Allow access to env from SupplierContextDsl

Closes gh-34943
pull/34944/head
Sébastien Deleuze 10 months ago
parent
commit
eed0a3ff59
  1. 6
      spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanRegistrarDsl.kt
  2. 20
      spring-context/src/test/kotlin/org/springframework/context/annotation/BeanRegistrarDslConfigurationTests.kt

6
spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanRegistrarDsl.kt

@ -302,7 +302,7 @@ open class BeanRegistrarDsl(private val init: BeanRegistrarDsl.() -> Unit): Bean
it.prototype() it.prototype()
} }
it.supplier { it.supplier {
SupplierContextDsl<T>(it).supplier() SupplierContextDsl<T>(it, env).supplier()
} }
val resolvableType = ResolvableType.forType(object: ParameterizedTypeReference<T>() {}); val resolvableType = ResolvableType.forType(object: ParameterizedTypeReference<T>() {});
if (resolvableType.hasGenerics()) { if (resolvableType.hasGenerics()) {
@ -370,7 +370,7 @@ open class BeanRegistrarDsl(private val init: BeanRegistrarDsl.() -> Unit): Bean
it.prototype() it.prototype()
} }
it.supplier { it.supplier {
SupplierContextDsl<T>(it).supplier() SupplierContextDsl<T>(it, env).supplier()
} }
val resolvableType = ResolvableType.forType(object: ParameterizedTypeReference<T>() {}); val resolvableType = ResolvableType.forType(object: ParameterizedTypeReference<T>() {});
if (resolvableType.hasGenerics()) { if (resolvableType.hasGenerics()) {
@ -1074,7 +1074,7 @@ open class BeanRegistrarDsl(private val init: BeanRegistrarDsl.() -> Unit): Bean
* to bean dependencies. * to bean dependencies.
*/ */
@BeanRegistrarDslMarker @BeanRegistrarDslMarker
open class SupplierContextDsl<T>(@PublishedApi internal val context: SupplierContext) { open class SupplierContextDsl<T>(@PublishedApi internal val context: SupplierContext, val env: Environment) {
/** /**
* Return the bean instance that uniquely matches the given object type, * Return the bean instance that uniquely matches the given object type,

20
spring-context/src/test/kotlin/org/springframework/context/annotation/BeanRegistrarDslConfigurationTests.kt

@ -18,7 +18,6 @@ package org.springframework.context.annotation
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.assertj.core.api.ThrowableAssert
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.springframework.beans.factory.BeanRegistrarDsl import org.springframework.beans.factory.BeanRegistrarDsl
import org.springframework.beans.factory.InitializingBean import org.springframework.beans.factory.InitializingBean
@ -26,6 +25,7 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException
import org.springframework.beans.factory.config.BeanDefinition import org.springframework.beans.factory.config.BeanDefinition
import org.springframework.beans.factory.getBean import org.springframework.beans.factory.getBean
import org.springframework.beans.factory.support.RootBeanDefinition import org.springframework.beans.factory.support.RootBeanDefinition
import org.springframework.mock.env.MockEnvironment
import java.util.function.Supplier import java.util.function.Supplier
/** /**
@ -37,10 +37,13 @@ class BeanRegistrarDslConfigurationTests {
@Test @Test
fun beanRegistrar() { fun beanRegistrar() {
val context = AnnotationConfigApplicationContext(BeanRegistrarKotlinConfiguration::class.java) val context = AnnotationConfigApplicationContext()
context.register(BeanRegistrarKotlinConfiguration::class.java)
context.environment = MockEnvironment().withProperty("hello.world", "Hello World!")
context.refresh()
assertThat(context.getBean<Bar>().foo).isEqualTo(context.getBean<Foo>()) assertThat(context.getBean<Bar>().foo).isEqualTo(context.getBean<Foo>())
assertThat(context.getBean<Foo>("foo")).isEqualTo(context.getBean<Foo>("fooAlias")) assertThat(context.getBean<Foo>("foo")).isEqualTo(context.getBean<Foo>("fooAlias"))
assertThatThrownBy(ThrowableAssert.ThrowingCallable { context.getBean<Baz>() }).isInstanceOf(NoSuchBeanDefinitionException::class.java) assertThatThrownBy { context.getBean<Baz>() }.isInstanceOf(NoSuchBeanDefinitionException::class.java)
assertThat(context.getBean<Init>().initialized).isTrue() assertThat(context.getBean<Init>().initialized).isTrue()
val beanDefinition = context.getBeanDefinition("bar") val beanDefinition = context.getBeanDefinition("bar")
assertThat(beanDefinition.scope).isEqualTo(BeanDefinition.SCOPE_PROTOTYPE) assertThat(beanDefinition.scope).isEqualTo(BeanDefinition.SCOPE_PROTOTYPE)
@ -53,7 +56,8 @@ class BeanRegistrarDslConfigurationTests {
fun beanRegistrarWithProfile() { fun beanRegistrarWithProfile() {
val context = AnnotationConfigApplicationContext() val context = AnnotationConfigApplicationContext()
context.register(BeanRegistrarKotlinConfiguration::class.java) context.register(BeanRegistrarKotlinConfiguration::class.java)
context.getEnvironment().addActiveProfile("baz") context.environment = MockEnvironment().withProperty("hello.world", "Hello World!")
context.environment.addActiveProfile("baz")
context.refresh() context.refresh()
assertThat(context.getBean<Bar>().foo).isEqualTo(context.getBean<Foo>()) assertThat(context.getBean<Bar>().foo).isEqualTo(context.getBean<Foo>())
assertThat(context.getBean<Baz>().message).isEqualTo("Hello World!") assertThat(context.getBean<Baz>().message).isEqualTo("Hello World!")
@ -101,7 +105,7 @@ class BeanRegistrarDslConfigurationTests {
Bar(bean<Foo>()) Bar(bean<Foo>())
} }
profile("baz") { profile("baz") {
registerBean { Baz("Hello World!") } registerBean { Baz(env.getRequiredProperty("hello.world")) }
} }
registerBean<Init>() registerBean<Init>()
registerBean(::booFactory, "fooFactory") registerBean(::booFactory, "fooFactory")
@ -113,11 +117,7 @@ class BeanRegistrarDslConfigurationTests {
private class GenericBeanRegistrar : BeanRegistrarDsl({ private class GenericBeanRegistrar : BeanRegistrarDsl({
registerBean<Supplier<Foo>>(name = "fooSupplier") { registerBean<Supplier<Foo>>(name = "fooSupplier") {
object: Supplier<Foo> { Supplier<Foo> { Foo() }
override fun get(): Foo {
return Foo()
}
}
} }
}) })

Loading…
Cancel
Save