diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index acc796df582..c65ab4ed1f7 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -142,7 +142,7 @@ public class InstanceSupplierCodeGenerator { if (constructorOrFactoryMethod instanceof Constructor constructor) { return generateCodeForConstructor(registeredBean, constructor); } - if (constructorOrFactoryMethod instanceof Method method) { + if (constructorOrFactoryMethod instanceof Method method && !KotlinDetector.isSuspendingFunction(method)) { return generateCodeForFactoryMethod(registeredBean, method, instantiationDescriptor.targetClass()); } throw new IllegalStateException( diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/ConstructorResolver.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/ConstructorResolver.java index 16adf097d4c..96773ccfee5 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/ConstructorResolver.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/ConstructorResolver.java @@ -63,6 +63,7 @@ import org.springframework.beans.factory.config.DependencyDescriptor; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.config.TypedStringValue; import org.springframework.core.CollectionFactory; +import org.springframework.core.KotlinDetector; import org.springframework.core.MethodParameter; import org.springframework.core.NamedThreadLocal; import org.springframework.core.ParameterNameDiscoverer; @@ -623,6 +624,11 @@ class ConstructorResolver { "Invalid factory method '" + mbd.getFactoryMethodName() + "' on class [" + factoryClass.getName() + "]: needs to have a non-void return type!"); } + else if (KotlinDetector.isKotlinPresent() && KotlinDetector.isSuspendingFunction(factoryMethodToUse)) { + throw new BeanCreationException(mbd.getResourceDescription(), beanName, + "Invalid factory method '" + mbd.getFactoryMethodName() + "' on class [" + + factoryClass.getName() + "]: suspending functions are not supported!"); + } else if (ambiguousFactoryMethods != null) { throw new BeanCreationException(mbd.getResourceDescription(), beanName, "Ambiguous factory method matches found on class [" + factoryClass.getName() + "] " + diff --git a/spring-beans/src/test/kotlin/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorKotlinTests.kt b/spring-beans/src/test/kotlin/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorKotlinTests.kt index 65c71ce2dda..63501a601c9 100644 --- a/spring-beans/src/test/kotlin/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorKotlinTests.kt +++ b/spring-beans/src/test/kotlin/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorKotlinTests.kt @@ -22,10 +22,8 @@ import org.junit.jupiter.api.Test import org.springframework.aot.hint.* import org.springframework.aot.test.generate.TestGenerationContext import org.springframework.beans.factory.config.BeanDefinition -import org.springframework.beans.factory.support.DefaultListableBeanFactory -import org.springframework.beans.factory.support.InstanceSupplier -import org.springframework.beans.factory.support.RegisteredBean -import org.springframework.beans.factory.support.RootBeanDefinition +import org.springframework.beans.factory.support.* +import org.springframework.beans.testfixture.beans.KotlinConfiguration import org.springframework.beans.testfixture.beans.KotlinTestBean import org.springframework.beans.testfixture.beans.KotlinTestBeanWithOptionalParameter import org.springframework.beans.testfixture.beans.factory.aot.DeferredTypeBuilder @@ -47,6 +45,8 @@ class InstanceSupplierCodeGeneratorKotlinTests { private val generationContext = TestGenerationContext() + private val beanFactory = DefaultListableBeanFactory() + @Test fun generateWhenHasDefaultConstructor() { val beanDefinition: BeanDefinition = RootBeanDefinition(KotlinTestBean::class.java) @@ -74,6 +74,39 @@ class InstanceSupplierCodeGeneratorKotlinTests { .satisfies(hasMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)) } + @Test + fun generateWhenHasFactoryMethodWithNoArg() { + val beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(String::class.java) + .setFactoryMethodOnBean("stringBean", "config").beanDefinition + this.beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder + .genericBeanDefinition(KotlinConfiguration::class.java).beanDefinition + ) + compile(beanFactory, beanDefinition) { instanceSupplier, compiled -> + val bean = getBean(beanFactory, beanDefinition, instanceSupplier) + Assertions.assertThat(bean).isInstanceOf(String::class.java) + Assertions.assertThat(bean).isEqualTo("Hello") + Assertions.assertThat(compiled.sourceFile).contains( + "getBeanFactory().getBean(KotlinConfiguration.class).stringBean()" + ) + } + Assertions.assertThat(getReflectionHints().getTypeHint(KotlinConfiguration::class.java)) + .satisfies(hasMethodWithMode(ExecutableMode.INTROSPECT)) + } + + @Test + fun generateWhenHasSuspendingFactoryMethod() { + val beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(String::class.java) + .setFactoryMethodOnBean("suspendingStringBean", "config").beanDefinition + this.beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder + .genericBeanDefinition(KotlinConfiguration::class.java).beanDefinition + ) + Assertions.assertThatIllegalStateException().isThrownBy { + compile(beanFactory, beanDefinition) { _, _ -> } + } + } + private fun getReflectionHints(): ReflectionHints { return generationContext.runtimeHints.reflection() } @@ -96,6 +129,12 @@ class InstanceSupplierCodeGeneratorKotlinTests { } } + private fun hasMethodWithMode(mode: ExecutableMode): ThrowingConsumer { + return ThrowingConsumer { hint: TypeHint -> + Assertions.assertThat(hint.methods()).anySatisfy(hasMode(mode)) + } + } + @Suppress("UNCHECKED_CAST") private fun getBean(beanFactory: DefaultListableBeanFactory, beanDefinition: BeanDefinition, instanceSupplier: InstanceSupplier<*>): T { diff --git a/spring-beans/src/test/kotlin/org/springframework/beans/factory/support/ConstructorResolverKotlinTests.kt b/spring-beans/src/test/kotlin/org/springframework/beans/factory/support/ConstructorResolverKotlinTests.kt new file mode 100644 index 00000000000..4aa8bb9d7c1 --- /dev/null +++ b/spring-beans/src/test/kotlin/org/springframework/beans/factory/support/ConstructorResolverKotlinTests.kt @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.support + +import org.assertj.core.api.Assertions +import org.junit.jupiter.api.Test +import org.springframework.beans.BeanWrapper +import org.springframework.beans.factory.BeanCreationException +import org.springframework.beans.factory.config.BeanDefinition +import org.springframework.beans.testfixture.beans.factory.generator.factory.KotlinFactory + +class ConstructorResolverKotlinTests { + + @Test + fun instantiateBeanInstanceWithBeanClassAndFactoryMethodName() { + val beanFactory = DefaultListableBeanFactory() + val beanDefinition: BeanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(KotlinFactory::class.java).setFactoryMethod("create") + .beanDefinition + val beanWrapper = instantiate(beanFactory, beanDefinition) + Assertions.assertThat(beanWrapper.wrappedInstance).isEqualTo("test") + } + + @Test + fun instantiateBeanInstanceWithBeanClassAndSuspendingFactoryMethodName() { + val beanFactory = DefaultListableBeanFactory() + val beanDefinition: BeanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(KotlinFactory::class.java).setFactoryMethod("suspendingCreate") + .beanDefinition + Assertions.assertThatThrownBy { instantiate(beanFactory, beanDefinition, null) } + .isInstanceOf(BeanCreationException::class.java) + .hasMessageContaining("suspending functions are not supported") + + } + + private fun instantiate(beanFactory: DefaultListableBeanFactory, beanDefinition: BeanDefinition, + vararg explicitArgs: Any?): BeanWrapper { + return ConstructorResolver(beanFactory) + .instantiateUsingFactoryMethod("testBean", (beanDefinition as RootBeanDefinition), explicitArgs) + } +} diff --git a/spring-beans/src/testFixtures/kotlin/org/springframework/beans/testfixture/beans/KotlinConfiguration.kt b/spring-beans/src/testFixtures/kotlin/org/springframework/beans/testfixture/beans/KotlinConfiguration.kt new file mode 100644 index 00000000000..24a61d0744b --- /dev/null +++ b/spring-beans/src/testFixtures/kotlin/org/springframework/beans/testfixture/beans/KotlinConfiguration.kt @@ -0,0 +1,28 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans + +class KotlinConfiguration { + + fun stringBean(): String { + return "Hello" + } + + suspend fun suspendingStringBean(): String { + return "Hello" + } +} diff --git a/spring-beans/src/testFixtures/kotlin/org/springframework/beans/testfixture/beans/factory/generator/factory/KotlinFactory.kt b/spring-beans/src/testFixtures/kotlin/org/springframework/beans/testfixture/beans/factory/generator/factory/KotlinFactory.kt new file mode 100644 index 00000000000..408d9421469 --- /dev/null +++ b/spring-beans/src/testFixtures/kotlin/org/springframework/beans/testfixture/beans/factory/generator/factory/KotlinFactory.kt @@ -0,0 +1,29 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans.factory.generator.factory + +class KotlinFactory { + + companion object { + + @JvmStatic + fun create() = "test" + + @JvmStatic + suspend fun suspendingCreate() = "test" + } +} diff --git a/spring-core/src/main/java/org/springframework/core/MethodParameter.java b/spring-core/src/main/java/org/springframework/core/MethodParameter.java index 40388e4f8cb..62855808227 100644 --- a/spring-core/src/main/java/org/springframework/core/MethodParameter.java +++ b/spring-core/src/main/java/org/springframework/core/MethodParameter.java @@ -720,7 +720,7 @@ public class MethodParameter { else if (this.executable instanceof Constructor constructor) { parameterNames = discoverer.getParameterNames(constructor); } - if (parameterNames != null) { + if (parameterNames != null && this.parameterIndex < parameterNames.length) { this.parameterName = parameterNames[this.parameterIndex]; } this.parameterNameDiscoverer = null; diff --git a/spring-core/src/test/kotlin/org/springframework/core/AbstractReflectionParameterNameDiscovererKotlinTests.kt b/spring-core/src/test/kotlin/org/springframework/core/AbstractReflectionParameterNameDiscovererKotlinTests.kt index d84f6c32196..583db8c6b1b 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/AbstractReflectionParameterNameDiscovererKotlinTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/AbstractReflectionParameterNameDiscovererKotlinTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.springframework.util.ReflectionUtils +import kotlin.coroutines.Continuation /** * Abstract tests for Kotlin [ParameterNameDiscoverer] aware implementations. @@ -46,6 +47,14 @@ abstract class AbstractReflectionParameterNameDiscovererKotlinTests(protected va assertThat(actualMethodParams).contains("message") } + @Test + fun getParameterNamesOnSuspendingFunction() { + val method = ReflectionUtils.findMethod(CoroutinesMessageService::class.java, "sendMessage", + String::class.java, Continuation::class.java)!! + val actualMethodParams = parameterNameDiscoverer.getParameterNames(method) + assertThat(actualMethodParams).containsExactly("message") + } + @Test fun getParameterNamesOnExtensionMethod() { val method = ReflectionUtils.findMethod(UtilityClass::class.java, "identity", String::class.java)!! @@ -65,4 +74,8 @@ abstract class AbstractReflectionParameterNameDiscovererKotlinTests(protected va fun String.identity() = this } + class CoroutinesMessageService { + suspend fun sendMessage(message: String) = message + } + } diff --git a/spring-core/src/test/kotlin/org/springframework/core/MethodParameterKotlinTests.kt b/spring-core/src/test/kotlin/org/springframework/core/MethodParameterKotlinTests.kt index c1c8901817e..18f9a04aab4 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/MethodParameterKotlinTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/MethodParameterKotlinTests.kt @@ -114,6 +114,27 @@ class MethodParameterKotlinTests { assertThat(returnGenericParameterType("suspendFun8")).isEqualTo(Object::class.java) } + @Test + fun `Parameter name for regular function`() { + val methodParameter = returnMethodParameter("nullable", 0) + methodParameter.initParameterNameDiscovery(KotlinReflectionParameterNameDiscoverer()) + assertThat(methodParameter.getParameterName()).isEqualTo("nullable") + } + + @Test + fun `Parameter name for suspending function`() { + val methodParameter = returnMethodParameter("suspendFun", 0) + methodParameter.initParameterNameDiscovery(KotlinReflectionParameterNameDiscoverer()) + assertThat(methodParameter.getParameterName()).isEqualTo("p1") + } + + @Test + fun `Continuation parameter name for suspending function`() { + val methodParameter = returnMethodParameter("suspendFun", 1) + methodParameter.initParameterNameDiscovery(KotlinReflectionParameterNameDiscoverer()) + assertThat(methodParameter.getParameterName()).isNull() + } + @Test fun `Continuation parameter is optional`() { val method = this::class.java.getDeclaredMethod("suspendFun", String::class.java, Continuation::class.java) @@ -126,8 +147,8 @@ class MethodParameterKotlinTests { private fun returnGenericParameterTypeName(funName: String) = returnGenericParameterType(funName).typeName private fun returnGenericParameterTypeBoundName(funName: String) = (returnGenericParameterType(funName) as TypeVariable<*>).bounds[0].typeName - private fun returnMethodParameter(funName: String) = - MethodParameter(this::class.declaredFunctions.first { it.name == funName }.javaMethod!!, -1) + private fun returnMethodParameter(funName: String, parameterIndex: Int = -1) = + MethodParameter(this::class.declaredFunctions.first { it.name == funName }.javaMethod!!, parameterIndex) @Suppress("unused_parameter") fun nullable(nullable: String?): Int? = 42