From fef3cf8e58a850cb7f6f7b3020afa120c7ed2d08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Mon, 18 Sep 2023 15:40:23 +0200 Subject: [PATCH] Review AOT-generated code for beanClass and targetType This commit reviews when an AOT-generated bean definition defines a beanClass or targetType. Previously, a beanClass was not consistently set which could lead to issues. Closes gh-31242 --- .../DefaultBeanRegistrationCodeFragments.java | 27 ++--- .../BeanDefinitionMethodGeneratorTests.java | 102 ++++++++++++++++-- 2 files changed, 108 insertions(+), 21 deletions(-) diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index 26df73d45ee..d5281ef6caf 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -117,7 +117,8 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme Class beanClass = (mergedBeanDefinition.hasBeanClass() ? ClassUtils.getUserClass(mergedBeanDefinition.getBeanClass()) : null); CodeBlock beanClassCode = generateBeanClassCode( - beanRegistrationCode.getClassName().packageName(), beanClass); + beanRegistrationCode.getClassName().packageName(), + (beanClass != null ? beanClass : beanType.toClass())); code.addStatement("$T $L = new $T($L)", RootBeanDefinition.class, BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, beanClassCode); if (targetTypeNecessary(beanType, beanClass)) { @@ -127,16 +128,13 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme return code.build(); } - private CodeBlock generateBeanClassCode(String targetPackage, @Nullable Class beanClass) { - if (beanClass != null) { - if (Modifier.isPublic(beanClass.getModifiers()) || targetPackage.equals(beanClass.getPackageName())) { - return CodeBlock.of("$T.class", beanClass); - } - else { - return CodeBlock.of("$S", beanClass.getName()); - } + private CodeBlock generateBeanClassCode(String targetPackage, Class beanClass) { + if (Modifier.isPublic(beanClass.getModifiers()) || targetPackage.equals(beanClass.getPackageName())) { + return CodeBlock.of("$T.class", beanClass); + } + else { + return CodeBlock.of("$S", beanClass.getName()); } - return CodeBlock.of(""); } private CodeBlock generateBeanTypeCode(ResolvableType beanType) { @@ -147,11 +145,14 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme } private boolean targetTypeNecessary(ResolvableType beanType, @Nullable Class beanClass) { - if (beanType.hasGenerics() || beanClass == null) { + if (beanType.hasGenerics()) { + return true; + } + if (beanClass != null + && this.registeredBean.getMergedBeanDefinition().getFactoryMethodName() != null) { return true; } - return (!beanType.toClass().equals(beanClass) - || this.registeredBean.getMergedBeanDefinition().getFactoryMethodName() != null); + return (beanClass != null && !beanType.toClass().equals(beanClass)); } @Override diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 3d490ec1a23..aae78a040a8 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -47,6 +47,7 @@ import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.factory.aot.InnerBeanConfiguration; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean; +import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanConfiguration; import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy; import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Implementation; import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.One; @@ -92,9 +93,27 @@ class BeanDefinitionMethodGeneratorTests { this.beanRegistrationsCode = new MockBeanRegistrationsCode(this.generationContext); } + @Test + void generateWithBeanClassSetsOnlyBeanClass() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(TestBean.class)"); + assertThat(sourceFile).doesNotContain("setTargetType("); + assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } @Test - void generateBeanDefinitionMethodWithOnlyTargetTypeDoesNotSetBeanClass() { + void generateWithTargetTypeWithNoGenericSetsOnlyBeanClass() { RootBeanDefinition beanDefinition = new RootBeanDefinition(); beanDefinition.setTargetType(TestBean.class); RegisteredBean registeredBean = registerBean(beanDefinition); @@ -106,16 +125,16 @@ class BeanDefinitionMethodGeneratorTests { compile(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); - assertThat(sourceFile).contains("new RootBeanDefinition()"); - assertThat(sourceFile).contains("setTargetType(TestBean.class)"); + assertThat(sourceFile).contains("new RootBeanDefinition(TestBean.class)"); assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); assertThat(actual).isInstanceOf(RootBeanDefinition.class); }); } @Test - void generateBeanDefinitionMethodSpecifiesBeanClassIfSet() { - RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class); + void generateWithTargetTypeUsingGenericsSetsBothBeanClassAndTargetType() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(); + beanDefinition.setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class)); RegisteredBean registeredBean = registerBean(beanDefinition); BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, @@ -123,17 +142,62 @@ class BeanDefinitionMethodGeneratorTests { MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); compile(method, (actual, compiled) -> { + assertThat(actual.getResolvableType().resolve()).isEqualTo(GenericBean.class); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); - assertThat(sourceFile).contains("new RootBeanDefinition(TestBean.class)"); + assertThat(sourceFile).contains("new RootBeanDefinition(GenericBean.class)"); + assertThat(sourceFile).contains( + "setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class))"); + assertThat(sourceFile).contains("setInstanceSupplier(GenericBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateWithBeanClassAndFactoryMethodNameSetsTargetTypeAndBeanClass() { + this.beanFactory.registerSingleton("factory", new SimpleBeanConfiguration()); + RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleBean.class); + beanDefinition.setFactoryBeanName("factory"); + beanDefinition.setFactoryMethodName("simpleBean"); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(SimpleBean.class)"); + assertThat(sourceFile).contains("setTargetType(SimpleBean.class)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + + @Test + void generateWithTargetTypeAndFactoryMethodNameSetsOnlyBeanClass() { + this.beanFactory.registerSingleton("factory", new SimpleBeanConfiguration()); + RootBeanDefinition beanDefinition = new RootBeanDefinition(); + beanDefinition.setTargetType(SimpleBean.class); + beanDefinition.setFactoryBeanName("factory"); + beanDefinition.setFactoryMethodName("simpleBean"); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(SimpleBean.class)"); assertThat(sourceFile).doesNotContain("setTargetType("); - assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); assertThat(actual).isInstanceOf(RootBeanDefinition.class); }); } @Test - void generateBeanDefinitionMethodSpecifiesBeanClassAndTargetTypIfDifferent() { + void generateWithBeanClassAndTargetTypeDifferentSetsBoth() { RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class); beanDefinition.setTargetType(Implementation.class); beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean")); @@ -152,6 +216,28 @@ class BeanDefinitionMethodGeneratorTests { }); } + @Test + void generateWithBeanClassAndTargetTypWithGenericSetsBoth() { + RootBeanDefinition beanDefinition = new RootBeanDefinition(Integer.class); + beanDefinition.setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class)); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + assertThat(actual.getResolvableType().resolve()).isEqualTo(GenericBean.class); + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("new RootBeanDefinition(Integer.class)"); + assertThat(sourceFile).contains( + "setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class))"); + assertThat(sourceFile).contains("setInstanceSupplier(GenericBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); + } + @Test void generateBeanDefinitionMethodUSeBeanClassNameIfNotReachable() { RootBeanDefinition beanDefinition = new RootBeanDefinition(PackagePrivateTestBean.class);