diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java index b128333f2a3..73934ac5b04 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -44,6 +44,7 @@ import org.springframework.util.StringUtils; * Generates a method that returns a {@link BeanDefinition} to be registered. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see BeanDefinitionMethodGeneratorFactory */ @@ -97,11 +98,7 @@ class BeanDefinitionMethodGenerator { ClassName target = codeFragments.getTarget(this.registeredBean, this.constructorOrFactoryMethod); if (!target.canonicalName().startsWith("java.")) { - GeneratedClass generatedClass = generationContext.getGeneratedClasses() - .getOrAddForFeatureComponent("BeanDefinitions", target, type -> { - type.addJavadoc("Bean definitions for {@link $T}", target); - type.addModifiers(Modifier.PUBLIC); - }); + GeneratedClass generatedClass = lookupGeneratedClass(generationContext, target); GeneratedMethods generatedMethods = generatedClass.getMethods() .withPrefix(getName()); GeneratedMethod generatedMethod = generateBeanDefinitionMethod( @@ -117,6 +114,43 @@ class BeanDefinitionMethodGenerator { return generatedMethod.toMethodReference(); } + /** + * Return the {@link GeneratedClass} to use for the specified {@code target}. + *

If the target class is an inner class, a corresponding inner class in + * the original structure is created. + * @param generationContext the generation context to use + * @param target the chosen target class name for the bean definition + * @return the generated class to use + */ + private static GeneratedClass lookupGeneratedClass(GenerationContext generationContext, ClassName target) { + ClassName topLevelClassName = target.topLevelClassName(); + GeneratedClass generatedClass = generationContext.getGeneratedClasses() + .getOrAddForFeatureComponent("BeanDefinitions", topLevelClassName, type -> { + type.addJavadoc("Bean definitions for {@link $T}", topLevelClassName); + type.addModifiers(Modifier.PUBLIC); + }); + List names = target.simpleNames(); + if (names.size() == 1) { + return generatedClass; + } + List namesToProcess = names.subList(1, names.size()); + ClassName currentTargetClassName = topLevelClassName; + GeneratedClass tmp = generatedClass; + for (String nameToProcess : namesToProcess) { + currentTargetClassName = currentTargetClassName.nestedClass(nameToProcess); + tmp = createInnerClass(tmp, nameToProcess + "__BeanDefinitions", currentTargetClassName); + } + return tmp; + } + + private static GeneratedClass createInnerClass(GeneratedClass generatedClass, + String name, ClassName target) { + return generatedClass.getOrAdd(name, type -> { + type.addJavadoc("Bean definitions for {@link $T}", target); + type.addModifiers(Modifier.PUBLIC, Modifier.STATIC); + }); + } + private BeanRegistrationCodeFragments getCodeFragments(GenerationContext generationContext, BeanRegistrationsCode beanRegistrationsCode) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 46dcb0370e1..a5a03f2ab7a 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -42,7 +42,9 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.AnnotatedBean; import org.springframework.beans.testfixture.beans.GenericBean; import org.springframework.beans.testfixture.beans.TestBean; +import org.springframework.beans.testfixture.beans.factory.aot.InnerBeanConfiguration; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; +import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean; import org.springframework.core.ResolvableType; import org.springframework.core.test.io.support.MockSpringFactoriesLoader; import org.springframework.core.test.tools.CompileWithForkedClassLoader; @@ -60,6 +62,7 @@ import static org.assertj.core.api.Assertions.assertThat; * {@link DefaultBeanRegistrationCodeFragments}. * * @author Phillip Webb + * @author Stephane Nicoll */ class BeanDefinitionMethodGeneratorTests { @@ -99,6 +102,52 @@ class BeanDefinitionMethodGeneratorTests { }); } + @Test + void generateBeanDefinitionMethodWhenHasInnerClassTargetMethodGeneratesMethod() { + this.beanFactory.registerBeanDefinition("testBeanConfiguration", new RootBeanDefinition( + InnerBeanConfiguration.Simple.class)); + RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleBean.class); + beanDefinition.setFactoryBeanName("testBeanConfiguration"); + beanDefinition.setFactoryMethodName("simpleBean"); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile.getClassName()).endsWith("InnerBeanConfiguration__BeanDefinitions"); + assertThat(sourceFile).contains("public static class Simple__BeanDefinitions") + .contains("Bean definitions for {@link InnerBeanConfiguration.Simple}") + .doesNotContain("Another__BeanDefinitions"); + + }); + } + + @Test + void generateBeanDefinitionMethodWhenHasNestedInnerClassTargetMethodGeneratesMethod() { + this.beanFactory.registerBeanDefinition("testBeanConfiguration", new RootBeanDefinition( + InnerBeanConfiguration.Simple.Another.class)); + RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleBean.class); + beanDefinition.setFactoryBeanName("testBeanConfiguration"); + beanDefinition.setFactoryMethodName("anotherBean"); + RegisteredBean registeredBean = registerBean(beanDefinition); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + Collections.emptyList()); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile.getClassName()).endsWith("InnerBeanConfiguration__BeanDefinitions"); + assertThat(sourceFile).contains("public static class Simple__BeanDefinitions") + .contains("Bean definitions for {@link InnerBeanConfiguration.Simple}") + .contains("public static class Another__BeanDefinitions") + .contains("Bean definitions for {@link InnerBeanConfiguration.Simple.Another}"); + }); + } + @Test void generateBeanDefinitionMethodWhenHasGenericsGeneratesMethod() { RegisteredBean registeredBean = registerBean(new RootBeanDefinition( diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/InnerBeanConfiguration.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/InnerBeanConfiguration.java new file mode 100644 index 00000000000..c3e5a7afcde --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/InnerBeanConfiguration.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans.factory.aot; + +/** + * A configuration with inner classes. + * + * @author Stephane Nicoll + */ +public class InnerBeanConfiguration { + + public static class Simple { + + public SimpleBean simpleBean() { + return new SimpleBean(); + } + + public static class Another { + + public SimpleBean anotherBean() { + return new SimpleBean(); + } + + } + } +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanConfiguration.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanConfiguration.java index 2c4e5ede429..dde310e37d3 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanConfiguration.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanConfiguration.java @@ -17,6 +17,7 @@ package org.springframework.beans.testfixture.beans.factory.aot; /** + * A sample configuration. * * @author Stephane Nicoll */ diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java index 5a1bb302187..6917e5051e4 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java @@ -24,6 +24,7 @@ import java.util.function.Consumer; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.TypeSpec; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -36,13 +37,18 @@ import org.springframework.util.Assert; */ public final class GeneratedClass { + @Nullable + private final GeneratedClass enclosingClass; + private final ClassName name; private final GeneratedMethods methods; private final Consumer type; - private final Map methodNameSequenceGenerator = new ConcurrentHashMap<>(); + private final Map declaredClasses; + + private final Map methodNameSequenceGenerator; /** @@ -53,9 +59,17 @@ public final class GeneratedClass { * @param type a {@link Consumer} used to build the type */ GeneratedClass(ClassName name, Consumer type) { + this(null, name, type); + } + + private GeneratedClass(@Nullable GeneratedClass enclosingClass, ClassName name, + Consumer type) { + this.enclosingClass = enclosingClass; this.name = name; this.type = type; this.methods = new GeneratedMethods(name, this::generateSequencedMethodName); + this.declaredClasses = new ConcurrentHashMap<>(); + this.methodNameSequenceGenerator = new ConcurrentHashMap<>(); } @@ -79,6 +93,16 @@ public final class GeneratedClass { return (sequence > 0) ? name.toString() + sequence : name.toString(); } + /** + * Return the enclosing {@link GeneratedClass} or {@code null} if this + * instance represents a top-level class. + * @return the enclosing generated class, if any + */ + @Nullable + public GeneratedClass getEnclosingClass() { + return this.enclosingClass; + } + /** * Return the name of the generated class. * @return the name of the generated class @@ -95,10 +119,33 @@ public final class GeneratedClass { return this.methods; } + /** + * Get or add a nested generated class with the specified name. If this method + * has previously been called with the given {@code name}, the existing class + * will be returned, otherwise a new class will be generated. + * @param name the name of the nested class + * @param type a {@link Consumer} used to build the type + * @return an existing or newly generated class whose enclosing class is this class + */ + public GeneratedClass getOrAdd(String name, Consumer type) { + ClassName className = this.name.nestedClass(name); + return this.declaredClasses.computeIfAbsent(className, + key -> new GeneratedClass(this, className, type)); + } + JavaFile generateJavaFile() { + Assert.state(getEnclosingClass() == null, + "Java file cannot be generated for an inner class"); + TypeSpec.Builder type = apply(); + return JavaFile.builder(this.name.packageName(), type.build()).build(); + } + + private TypeSpec.Builder apply() { TypeSpec.Builder type = getBuilder(this.type); this.methods.doWithMethodSpecs(type::addMethod); - return JavaFile.builder(this.name.packageName(), type.build()).build(); + this.declaredClasses.values().forEach(declaredClass -> + type.addType(declaredClass.apply().build())); + return type; } private TypeSpec.Builder getBuilder(Consumer type) { diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java index bab98ebe6b2..23802ed069f 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java @@ -18,6 +18,8 @@ package org.springframework.aot.generate; import java.util.function.Consumer; +import javax.lang.model.element.Modifier; + import org.junit.jupiter.api.Test; import org.springframework.javapoet.ClassName; @@ -35,21 +37,34 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException; */ class GeneratedClassTests { + private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test"); + private static final Consumer emptyTypeCustomizer = type -> {}; private static final Consumer emptyMethodCustomizer = method -> {}; + @Test + void getEnclosingNameOnTopLevelClassReturnsNull() { + GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME); + assertThat(generatedClass.getEnclosingClass()).isNull(); + } + + @Test + void getEnclosingNameOnInnerClassReturnsParent() { + GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME); + GeneratedClass innerGeneratedClass = generatedClass.getOrAdd("Test", emptyTypeCustomizer); + assertThat(innerGeneratedClass.getEnclosingClass()).isEqualTo(generatedClass); + } + @Test void getNameReturnsName() { - ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer); - assertThat(generatedClass.getName()).isSameAs(name); + GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME); + assertThat(generatedClass.getName()).isSameAs(TEST_CLASS_NAME); } @Test void reserveMethodNamesWhenNameUsedThrowsException() { - ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer); + GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME); generatedClass.getMethods().add("apply", emptyMethodCustomizer); assertThatIllegalStateException() .isThrownBy(() -> generatedClass.reserveMethodNames("apply")); @@ -57,8 +72,7 @@ class GeneratedClassTests { @Test void reserveMethodNamesReservesNames() { - ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer); + GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME); generatedClass.reserveMethodNames("apply"); GeneratedMethod generatedMethod = generatedClass.getMethods().add("apply", emptyMethodCustomizer); assertThat(generatedMethod.getName()).isEqualTo("apply1"); @@ -66,18 +80,45 @@ class GeneratedClassTests { @Test void generateMethodNameWhenAllEmptyPartsGeneratesSetName() { - ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer); + GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME); GeneratedMethod generatedMethod = generatedClass.getMethods().add("123", emptyMethodCustomizer); assertThat(generatedMethod.getName()).isEqualTo("$$aot"); } + @Test + void getOrAddWhenRepeatReturnsSameGeneratedClass() { + GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME); + GeneratedClass innerGeneratedClass = generatedClass.getOrAdd("Inner", emptyTypeCustomizer); + GeneratedClass innerGeneratedClass2 = generatedClass.getOrAdd("Inner", emptyTypeCustomizer); + GeneratedClass innerGeneratedClass3 = generatedClass.getOrAdd("Inner", emptyTypeCustomizer); + assertThat(innerGeneratedClass).isSameAs(innerGeneratedClass2).isSameAs(innerGeneratedClass3); + } + @Test void generateJavaFileIncludesGeneratedMethods() { - ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer); + GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME); generatedClass.getMethods().add("test", method -> method.addJavadoc("Test Method")); assertThat(generatedClass.generateJavaFile().toString()).contains("Test Method"); } + @Test + void generateJavaFileIncludesDeclaredClasses() { + GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME); + generatedClass.getOrAdd("First", type -> type.modifiers.add(Modifier.STATIC)); + generatedClass.getOrAdd("Second", type -> type.modifiers.add(Modifier.PRIVATE)); + assertThat(generatedClass.generateJavaFile().toString()) + .contains("static class First").contains("private class Second"); + } + + @Test + void generateJavaFileOnInnerClassThrowsException() { + GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME) + .getOrAdd("Inner", emptyTypeCustomizer); + assertThatIllegalStateException().isThrownBy(generatedClass::generateJavaFile); + } + + private static GeneratedClass createGeneratedClass(ClassName className) { + return new GeneratedClass(className, emptyTypeCustomizer); + } + }