diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index cf4e4f045c4..97a3df3fd66 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -16,12 +16,14 @@ package org.springframework.beans.factory.aot; +import java.lang.reflect.Constructor; import java.lang.reflect.Executable; import java.util.List; import java.util.function.Predicate; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; import org.springframework.beans.factory.support.InstanceSupplier; @@ -69,14 +71,21 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments public Class getTarget(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { - Class target = ClassUtils - .getUserClass(constructorOrFactoryMethod.getDeclaringClass()); + Class target = extractDeclaringClass(constructorOrFactoryMethod); while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) { target = registeredBean.getParent().getBeanClass(); } return target; } + private Class extractDeclaringClass(Executable executable) { + Class declaringClass = ClassUtils.getUserClass(executable.getDeclaringClass()); + if (executable instanceof Constructor && FactoryBean.class.isAssignableFrom(declaringClass)) { + return ResolvableType.forType(declaringClass).as(FactoryBean.class).getGeneric(0).toClass(); + } + return executable.getDeclaringClass(); + } + @Override public CodeBlock generateNewBeanDefinitionCode(GenerationContext generationContext, ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) { @@ -107,7 +116,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments generationContext.getRuntimeHints(), attributeFilter, beanRegistrationCode.getMethods(), (name, value) -> generateValueCode(generationContext, name, value)) - .generateCode(beanDefinition); + .generateCode(beanDefinition); } @Nullable @@ -171,7 +180,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments return new InstanceSupplierCodeGenerator(generationContext, beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) - .generateCode(this.registeredBean, constructorOrFactoryMethod); + .generateCode(this.registeredBean, constructorOrFactoryMethod); } @Override diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java new file mode 100644 index 00000000000..6ac4dfef376 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java @@ -0,0 +1,157 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.aot; + +import java.lang.reflect.Method; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.annotation.InjectAnnotationBeanPostProcessorTests.StringFactoryBean; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.factory.DummyFactory; +import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link DefaultBeanRegistrationCodeFragments}. + * + * @author Stephane Nicoll + */ +class DefaultBeanRegistrationCodeFragmentsTests { + + private final BeanRegistrationsCode beanRegistrationsCode = new MockBeanRegistrationsCode(new TestGenerationContext()); + + private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + + @Test + void getTargetOnConstructor() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + assertThat(createInstance(registeredBean).getTarget(registeredBean, + TestBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + @Test + void getTargetOnConstructorToFactoryBean() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + assertThat(createInstance(registeredBean).getTarget(registeredBean, + TestBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + @Test + void getTargetOnMethod() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + Method method = ReflectionUtils.findMethod(TestBeanFactoryBean.class, "getObject"); + assertThat(method).isNotNull(); + assertThat(createInstance(registeredBean).getTarget(registeredBean, + method)).isEqualTo(TestBeanFactoryBean.class); + } + + @Test + void getTargetOnMethodWithInnerBeanInJavaPackage() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class)); + Method method = ReflectionUtils.findMethod(getClass(), "createString"); + assertThat(method).isNotNull(); + assertThat(createInstance(innerBean).getTarget(innerBean, + method)).isEqualTo(getClass()); + } + + @Test + void getTargetOnConstructorWithInnerBeanInJavaPackage() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class)); + assertThat(createInstance(innerBean).getTarget(innerBean, + String.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + @Test + void getTargetOnConstructorWithInnerBeanOnTypeInJavaPackage() { + RegisteredBean registeredBean = registerTestBean(TestBean.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", + new RootBeanDefinition(StringFactoryBean.class)); + assertThat(createInstance(innerBean).getTarget(innerBean, + StringFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + @Test + void getTargetOnMethodWithInnerBeanInRegularPackage() { + RegisteredBean registeredBean = registerTestBean(DummyFactory.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(TestBean.class)); + Method method = ReflectionUtils.findMethod(TestBeanFactoryBean.class, "getObject"); + assertThat(method).isNotNull(); + assertThat(createInstance(innerBean).getTarget(innerBean, method)).isEqualTo(TestBeanFactoryBean.class); + } + + @Test + void getTargetOnConstructorWithInnerBeanInRegularPackage() { + RegisteredBean registeredBean = registerTestBean(DummyFactory.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(TestBean.class)); + assertThat(createInstance(innerBean).getTarget(innerBean, + TestBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + @Test + void getTargetOnConstructorWithInnerBeanOnFactoryBeanOnTypeInRegularPackage() { + RegisteredBean registeredBean = registerTestBean(DummyFactory.class); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", + new RootBeanDefinition(TestBean.class)); + assertThat(createInstance(innerBean).getTarget(innerBean, + TestBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class); + } + + + private RegisteredBean registerTestBean(Class beanType) { + this.beanFactory.registerBeanDefinition("testBean", + new RootBeanDefinition(beanType)); + return RegisteredBean.of(this.beanFactory, "testBean"); + } + + private BeanRegistrationCodeFragments createInstance(RegisteredBean registeredBean) { + return new DefaultBeanRegistrationCodeFragments(this.beanRegistrationsCode, registeredBean, new BeanDefinitionMethodGeneratorFactory(this.beanFactory)); + } + + @SuppressWarnings("unused") + static String createString() { + return "Test"; + } + + @SuppressWarnings("unused") + static class TestBean { + + } + + + static class TestBeanFactoryBean implements FactoryBean { + + @Override + public TestBean getObject() throws Exception { + return new TestBean(); + } + + @Override + public Class getObjectType() { + return TestBean.class; + } + } + +}