diff --git a/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanFactoryExtension.kt b/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanFactoryExtension.kt index 9ff92b9fadb..e7186535e26 100644 --- a/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanFactoryExtension.kt +++ b/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanFactoryExtension.kt @@ -15,16 +15,33 @@ object BeanFactoryExtension { */ fun BeanFactory.getBean(requiredType: KClass) = getBean(requiredType.java) + /** + * @see BeanFactory.getBean(Class) + */ + inline fun BeanFactory.getBean() = getBean(T::class.java) + /** * @see BeanFactory.getBean(String, Class) */ fun BeanFactory.getBean(name: String, requiredType: KClass) = getBean(name, requiredType.java) + /** + * @see BeanFactory.getBean(String, Class) + */ + inline fun BeanFactory.getBean(name: String) = + getBean(name, T::class.java) + /** * @see BeanFactory.getBean(Class, Object...) */ fun BeanFactory.getBean(requiredType: KClass, vararg args:Any) = getBean(requiredType.java, *args) + /** + * @see BeanFactory.getBean(Class, Object...) + */ + inline fun BeanFactory.getBean(vararg args:Any) = + getBean(T::class.java, *args) + } diff --git a/spring-beans/src/main/kotlin/org/springframework/beans/factory/ListableBeanFactoryExtension.kt b/spring-beans/src/main/kotlin/org/springframework/beans/factory/ListableBeanFactoryExtension.kt index 26004217f65..9cb0f3be897 100644 --- a/spring-beans/src/main/kotlin/org/springframework/beans/factory/ListableBeanFactoryExtension.kt +++ b/spring-beans/src/main/kotlin/org/springframework/beans/factory/ListableBeanFactoryExtension.kt @@ -16,6 +16,12 @@ object ListableBeanFactoryExtension { fun ListableBeanFactory.getBeanNamesForType(type: KClass) = getBeanNamesForType(type.java) + /** + * @see ListableBeanFactory.getBeanNamesForType(Class) + */ + inline fun ListableBeanFactory.getBeanNamesForType() = + getBeanNamesForType(T::class.java) + /** * @see ListableBeanFactory.getBeanNamesForType(Class, boolean, boolean) */ @@ -23,12 +29,24 @@ object ListableBeanFactoryExtension { includeNonSingletons: Boolean, allowEagerInit: Boolean) = getBeanNamesForType(type.java, includeNonSingletons, allowEagerInit) + /** + * @see ListableBeanFactory.getBeanNamesForType(Class, boolean, boolean) + */ + inline fun ListableBeanFactory.getBeanNamesForType(includeNonSingletons: Boolean, allowEagerInit: Boolean) = + getBeanNamesForType(T::class.java, includeNonSingletons, allowEagerInit) + /** * @see ListableBeanFactory.getBeansOfType(Class) */ fun ListableBeanFactory.getBeansOfType(type: KClass) = getBeansOfType(type.java) + /** + * @see ListableBeanFactory.getBeansOfType(Class) + */ + inline fun ListableBeanFactory.getBeansOfType() = + getBeansOfType(T::class.java) + /** * @see ListableBeanFactory.getBeansOfType(Class, boolean, boolean) */ @@ -36,22 +54,46 @@ object ListableBeanFactoryExtension { includeNonSingletons: Boolean, allowEagerInit: Boolean) = getBeansOfType(type.java, includeNonSingletons, allowEagerInit) + /** + * @see ListableBeanFactory.getBeansOfType(Class, boolean, boolean) + */ + inline fun ListableBeanFactory.getBeansOfType(includeNonSingletons: Boolean, allowEagerInit: Boolean) = + getBeansOfType(T::class.java, includeNonSingletons, allowEagerInit) + /** * @see ListableBeanFactory.getBeanNamesForAnnotation */ fun ListableBeanFactory.getBeanNamesForAnnotation(type: KClass) = getBeanNamesForAnnotation(type.java) + /** + * @see ListableBeanFactory.getBeanNamesForAnnotation + */ + inline fun ListableBeanFactory.getBeanNamesForAnnotation() = + getBeanNamesForAnnotation(T::class.java) + /** * @see ListableBeanFactory.getBeansWithAnnotation */ fun ListableBeanFactory.getBeansWithAnnotation(type: KClass) = getBeansWithAnnotation(type.java) + /** + * @see ListableBeanFactory.getBeansWithAnnotation + */ + inline fun ListableBeanFactory.getBeansWithAnnotation() = + getBeansWithAnnotation(T::class.java) + /** * @see ListableBeanFactoryExtension.findAnnotationOnBean */ fun ListableBeanFactory.findAnnotationOnBean(beanName:String, type: KClass) = findAnnotationOnBean(beanName, type.java) + /** + * @see ListableBeanFactoryExtension.findAnnotationOnBean + */ + inline fun ListableBeanFactory.findAnnotationOnBean(beanName:String) = + findAnnotationOnBean(beanName, T::class.java) + } diff --git a/spring-context/src/main/kotlin/org/springframework/context/support/GenericApplicationContextExtension.kt b/spring-context/src/main/kotlin/org/springframework/context/support/GenericApplicationContextExtension.kt index 1bf3a296802..5e1dc33ede3 100644 --- a/spring-context/src/main/kotlin/org/springframework/context/support/GenericApplicationContextExtension.kt +++ b/spring-context/src/main/kotlin/org/springframework/context/support/GenericApplicationContextExtension.kt @@ -23,6 +23,13 @@ object GenericApplicationContextExtension { registerBean(beanClass.java, *customizers) } + /** + * @see GenericApplicationContext.registerBean(Class, BeanDefinitionCustomizer...) + */ + inline fun GenericApplicationContext.registerBean(vararg customizers: BeanDefinitionCustomizer) { + registerBean(T::class.java, *customizers) + } + /** * @see GenericApplicationContext.registerBean(String, Class, BeanDefinitionCustomizer...) */ @@ -31,6 +38,13 @@ object GenericApplicationContextExtension { registerBean(beanName, beanClass.java, *customizers) } + /** + * @see GenericApplicationContext.registerBean(String, Class, BeanDefinitionCustomizer...) + */ + inline fun GenericApplicationContext.registerBean(beanName: String, vararg customizers: BeanDefinitionCustomizer) { + registerBean(beanName, T::class.java, *customizers) + } + /** * @see GenericApplicationContext.registerBean(Class, Supplier, BeanDefinitionCustomizer...) */ @@ -46,4 +60,6 @@ object GenericApplicationContextExtension { vararg customizers: BeanDefinitionCustomizer, crossinline function: (ApplicationContext) -> T) { registerBean(name, T::class.java, Supplier { function.invoke(this) }, *customizers) } + + fun GenericApplicationContext(configure: GenericApplicationContext.()->Unit) = GenericApplicationContext().apply(configure) } diff --git a/spring-context/src/test/kotlin/org/springframework/context/support/GenericApplicationContextExtensionTests.kt b/spring-context/src/test/kotlin/org/springframework/context/support/GenericApplicationContextExtensionTests.kt index daaa8873cfb..ef38c0eb8b6 100644 --- a/spring-context/src/test/kotlin/org/springframework/context/support/GenericApplicationContextExtensionTests.kt +++ b/spring-context/src/test/kotlin/org/springframework/context/support/GenericApplicationContextExtensionTests.kt @@ -4,6 +4,7 @@ import org.junit.Assert.assertNotNull import org.junit.Test import org.springframework.context.support.GenericApplicationContextExtension.registerBean import org.springframework.beans.factory.BeanFactoryExtension.getBean +import org.springframework.context.support.GenericApplicationContextExtension.GenericApplicationContext class GenericApplicationContextExtensionTests { @@ -59,6 +60,17 @@ class GenericApplicationContextExtensionTests { assertNotNull(context.getBean("b")) } + @Test + fun registerBeanWithGradleStyleApi() { + val context = GenericApplicationContext { + registerBean() + registerBean { BeanB(it.getBean()) } + } + context.refresh() + assertNotNull(context.getBean()) + assertNotNull(context.getBean()) + } + internal class BeanA internal class BeanB(val a: BeanA)