diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java
index b128333f2a3..73934ac5b04 100644
--- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java
+++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java
@@ -44,6 +44,7 @@ import org.springframework.util.StringUtils;
* Generates a method that returns a {@link BeanDefinition} to be registered.
*
* @author Phillip Webb
+ * @author Stephane Nicoll
* @since 6.0
* @see BeanDefinitionMethodGeneratorFactory
*/
@@ -97,11 +98,7 @@ class BeanDefinitionMethodGenerator {
ClassName target = codeFragments.getTarget(this.registeredBean,
this.constructorOrFactoryMethod);
if (!target.canonicalName().startsWith("java.")) {
- GeneratedClass generatedClass = generationContext.getGeneratedClasses()
- .getOrAddForFeatureComponent("BeanDefinitions", target, type -> {
- type.addJavadoc("Bean definitions for {@link $T}", target);
- type.addModifiers(Modifier.PUBLIC);
- });
+ GeneratedClass generatedClass = lookupGeneratedClass(generationContext, target);
GeneratedMethods generatedMethods = generatedClass.getMethods()
.withPrefix(getName());
GeneratedMethod generatedMethod = generateBeanDefinitionMethod(
@@ -117,6 +114,43 @@ class BeanDefinitionMethodGenerator {
return generatedMethod.toMethodReference();
}
+ /**
+ * Return the {@link GeneratedClass} to use for the specified {@code target}.
+ *
If the target class is an inner class, a corresponding inner class in
+ * the original structure is created.
+ * @param generationContext the generation context to use
+ * @param target the chosen target class name for the bean definition
+ * @return the generated class to use
+ */
+ private static GeneratedClass lookupGeneratedClass(GenerationContext generationContext, ClassName target) {
+ ClassName topLevelClassName = target.topLevelClassName();
+ GeneratedClass generatedClass = generationContext.getGeneratedClasses()
+ .getOrAddForFeatureComponent("BeanDefinitions", topLevelClassName, type -> {
+ type.addJavadoc("Bean definitions for {@link $T}", topLevelClassName);
+ type.addModifiers(Modifier.PUBLIC);
+ });
+ List names = target.simpleNames();
+ if (names.size() == 1) {
+ return generatedClass;
+ }
+ List namesToProcess = names.subList(1, names.size());
+ ClassName currentTargetClassName = topLevelClassName;
+ GeneratedClass tmp = generatedClass;
+ for (String nameToProcess : namesToProcess) {
+ currentTargetClassName = currentTargetClassName.nestedClass(nameToProcess);
+ tmp = createInnerClass(tmp, nameToProcess + "__BeanDefinitions", currentTargetClassName);
+ }
+ return tmp;
+ }
+
+ private static GeneratedClass createInnerClass(GeneratedClass generatedClass,
+ String name, ClassName target) {
+ return generatedClass.getOrAdd(name, type -> {
+ type.addJavadoc("Bean definitions for {@link $T}", target);
+ type.addModifiers(Modifier.PUBLIC, Modifier.STATIC);
+ });
+ }
+
private BeanRegistrationCodeFragments getCodeFragments(GenerationContext generationContext,
BeanRegistrationsCode beanRegistrationsCode) {
diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java
index 46dcb0370e1..a5a03f2ab7a 100644
--- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java
+++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java
@@ -42,7 +42,9 @@ import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.AnnotatedBean;
import org.springframework.beans.testfixture.beans.GenericBean;
import org.springframework.beans.testfixture.beans.TestBean;
+import org.springframework.beans.testfixture.beans.factory.aot.InnerBeanConfiguration;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode;
+import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean;
import org.springframework.core.ResolvableType;
import org.springframework.core.test.io.support.MockSpringFactoriesLoader;
import org.springframework.core.test.tools.CompileWithForkedClassLoader;
@@ -60,6 +62,7 @@ import static org.assertj.core.api.Assertions.assertThat;
* {@link DefaultBeanRegistrationCodeFragments}.
*
* @author Phillip Webb
+ * @author Stephane Nicoll
*/
class BeanDefinitionMethodGeneratorTests {
@@ -99,6 +102,52 @@ class BeanDefinitionMethodGeneratorTests {
});
}
+ @Test
+ void generateBeanDefinitionMethodWhenHasInnerClassTargetMethodGeneratesMethod() {
+ this.beanFactory.registerBeanDefinition("testBeanConfiguration", new RootBeanDefinition(
+ InnerBeanConfiguration.Simple.class));
+ RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleBean.class);
+ beanDefinition.setFactoryBeanName("testBeanConfiguration");
+ beanDefinition.setFactoryMethodName("simpleBean");
+ RegisteredBean registeredBean = registerBean(beanDefinition);
+ BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
+ this.methodGeneratorFactory, registeredBean, null,
+ Collections.emptyList());
+ MethodReference method = generator.generateBeanDefinitionMethod(
+ this.generationContext, this.beanRegistrationsCode);
+ compile(method, (actual, compiled) -> {
+ SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
+ assertThat(sourceFile.getClassName()).endsWith("InnerBeanConfiguration__BeanDefinitions");
+ assertThat(sourceFile).contains("public static class Simple__BeanDefinitions")
+ .contains("Bean definitions for {@link InnerBeanConfiguration.Simple}")
+ .doesNotContain("Another__BeanDefinitions");
+
+ });
+ }
+
+ @Test
+ void generateBeanDefinitionMethodWhenHasNestedInnerClassTargetMethodGeneratesMethod() {
+ this.beanFactory.registerBeanDefinition("testBeanConfiguration", new RootBeanDefinition(
+ InnerBeanConfiguration.Simple.Another.class));
+ RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleBean.class);
+ beanDefinition.setFactoryBeanName("testBeanConfiguration");
+ beanDefinition.setFactoryMethodName("anotherBean");
+ RegisteredBean registeredBean = registerBean(beanDefinition);
+ BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
+ this.methodGeneratorFactory, registeredBean, null,
+ Collections.emptyList());
+ MethodReference method = generator.generateBeanDefinitionMethod(
+ this.generationContext, this.beanRegistrationsCode);
+ compile(method, (actual, compiled) -> {
+ SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
+ assertThat(sourceFile.getClassName()).endsWith("InnerBeanConfiguration__BeanDefinitions");
+ assertThat(sourceFile).contains("public static class Simple__BeanDefinitions")
+ .contains("Bean definitions for {@link InnerBeanConfiguration.Simple}")
+ .contains("public static class Another__BeanDefinitions")
+ .contains("Bean definitions for {@link InnerBeanConfiguration.Simple.Another}");
+ });
+ }
+
@Test
void generateBeanDefinitionMethodWhenHasGenericsGeneratesMethod() {
RegisteredBean registeredBean = registerBean(new RootBeanDefinition(
diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/InnerBeanConfiguration.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/InnerBeanConfiguration.java
new file mode 100644
index 00000000000..c3e5a7afcde
--- /dev/null
+++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/InnerBeanConfiguration.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.beans.testfixture.beans.factory.aot;
+
+/**
+ * A configuration with inner classes.
+ *
+ * @author Stephane Nicoll
+ */
+public class InnerBeanConfiguration {
+
+ public static class Simple {
+
+ public SimpleBean simpleBean() {
+ return new SimpleBean();
+ }
+
+ public static class Another {
+
+ public SimpleBean anotherBean() {
+ return new SimpleBean();
+ }
+
+ }
+ }
+}
diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanConfiguration.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanConfiguration.java
index 2c4e5ede429..dde310e37d3 100644
--- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanConfiguration.java
+++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanConfiguration.java
@@ -17,6 +17,7 @@
package org.springframework.beans.testfixture.beans.factory.aot;
/**
+ * A sample configuration.
*
* @author Stephane Nicoll
*/
diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java
index 5a1bb302187..6917e5051e4 100644
--- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java
+++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java
@@ -24,6 +24,7 @@ import java.util.function.Consumer;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.TypeSpec;
+import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
@@ -36,13 +37,18 @@ import org.springframework.util.Assert;
*/
public final class GeneratedClass {
+ @Nullable
+ private final GeneratedClass enclosingClass;
+
private final ClassName name;
private final GeneratedMethods methods;
private final Consumer type;
- private final Map methodNameSequenceGenerator = new ConcurrentHashMap<>();
+ private final Map declaredClasses;
+
+ private final Map methodNameSequenceGenerator;
/**
@@ -53,9 +59,17 @@ public final class GeneratedClass {
* @param type a {@link Consumer} used to build the type
*/
GeneratedClass(ClassName name, Consumer type) {
+ this(null, name, type);
+ }
+
+ private GeneratedClass(@Nullable GeneratedClass enclosingClass, ClassName name,
+ Consumer type) {
+ this.enclosingClass = enclosingClass;
this.name = name;
this.type = type;
this.methods = new GeneratedMethods(name, this::generateSequencedMethodName);
+ this.declaredClasses = new ConcurrentHashMap<>();
+ this.methodNameSequenceGenerator = new ConcurrentHashMap<>();
}
@@ -79,6 +93,16 @@ public final class GeneratedClass {
return (sequence > 0) ? name.toString() + sequence : name.toString();
}
+ /**
+ * Return the enclosing {@link GeneratedClass} or {@code null} if this
+ * instance represents a top-level class.
+ * @return the enclosing generated class, if any
+ */
+ @Nullable
+ public GeneratedClass getEnclosingClass() {
+ return this.enclosingClass;
+ }
+
/**
* Return the name of the generated class.
* @return the name of the generated class
@@ -95,10 +119,33 @@ public final class GeneratedClass {
return this.methods;
}
+ /**
+ * Get or add a nested generated class with the specified name. If this method
+ * has previously been called with the given {@code name}, the existing class
+ * will be returned, otherwise a new class will be generated.
+ * @param name the name of the nested class
+ * @param type a {@link Consumer} used to build the type
+ * @return an existing or newly generated class whose enclosing class is this class
+ */
+ public GeneratedClass getOrAdd(String name, Consumer type) {
+ ClassName className = this.name.nestedClass(name);
+ return this.declaredClasses.computeIfAbsent(className,
+ key -> new GeneratedClass(this, className, type));
+ }
+
JavaFile generateJavaFile() {
+ Assert.state(getEnclosingClass() == null,
+ "Java file cannot be generated for an inner class");
+ TypeSpec.Builder type = apply();
+ return JavaFile.builder(this.name.packageName(), type.build()).build();
+ }
+
+ private TypeSpec.Builder apply() {
TypeSpec.Builder type = getBuilder(this.type);
this.methods.doWithMethodSpecs(type::addMethod);
- return JavaFile.builder(this.name.packageName(), type.build()).build();
+ this.declaredClasses.values().forEach(declaredClass ->
+ type.addType(declaredClass.apply().build()));
+ return type;
}
private TypeSpec.Builder getBuilder(Consumer type) {
diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java
index bab98ebe6b2..23802ed069f 100644
--- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java
+++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java
@@ -18,6 +18,8 @@ package org.springframework.aot.generate;
import java.util.function.Consumer;
+import javax.lang.model.element.Modifier;
+
import org.junit.jupiter.api.Test;
import org.springframework.javapoet.ClassName;
@@ -35,21 +37,34 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
*/
class GeneratedClassTests {
+ private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test");
+
private static final Consumer emptyTypeCustomizer = type -> {};
private static final Consumer emptyMethodCustomizer = method -> {};
+ @Test
+ void getEnclosingNameOnTopLevelClassReturnsNull() {
+ GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME);
+ assertThat(generatedClass.getEnclosingClass()).isNull();
+ }
+
+ @Test
+ void getEnclosingNameOnInnerClassReturnsParent() {
+ GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME);
+ GeneratedClass innerGeneratedClass = generatedClass.getOrAdd("Test", emptyTypeCustomizer);
+ assertThat(innerGeneratedClass.getEnclosingClass()).isEqualTo(generatedClass);
+ }
+
@Test
void getNameReturnsName() {
- ClassName name = ClassName.bestGuess("com.example.Test");
- GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer);
- assertThat(generatedClass.getName()).isSameAs(name);
+ GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME);
+ assertThat(generatedClass.getName()).isSameAs(TEST_CLASS_NAME);
}
@Test
void reserveMethodNamesWhenNameUsedThrowsException() {
- ClassName name = ClassName.bestGuess("com.example.Test");
- GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer);
+ GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME);
generatedClass.getMethods().add("apply", emptyMethodCustomizer);
assertThatIllegalStateException()
.isThrownBy(() -> generatedClass.reserveMethodNames("apply"));
@@ -57,8 +72,7 @@ class GeneratedClassTests {
@Test
void reserveMethodNamesReservesNames() {
- ClassName name = ClassName.bestGuess("com.example.Test");
- GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer);
+ GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME);
generatedClass.reserveMethodNames("apply");
GeneratedMethod generatedMethod = generatedClass.getMethods().add("apply", emptyMethodCustomizer);
assertThat(generatedMethod.getName()).isEqualTo("apply1");
@@ -66,18 +80,45 @@ class GeneratedClassTests {
@Test
void generateMethodNameWhenAllEmptyPartsGeneratesSetName() {
- ClassName name = ClassName.bestGuess("com.example.Test");
- GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer);
+ GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME);
GeneratedMethod generatedMethod = generatedClass.getMethods().add("123", emptyMethodCustomizer);
assertThat(generatedMethod.getName()).isEqualTo("$$aot");
}
+ @Test
+ void getOrAddWhenRepeatReturnsSameGeneratedClass() {
+ GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME);
+ GeneratedClass innerGeneratedClass = generatedClass.getOrAdd("Inner", emptyTypeCustomizer);
+ GeneratedClass innerGeneratedClass2 = generatedClass.getOrAdd("Inner", emptyTypeCustomizer);
+ GeneratedClass innerGeneratedClass3 = generatedClass.getOrAdd("Inner", emptyTypeCustomizer);
+ assertThat(innerGeneratedClass).isSameAs(innerGeneratedClass2).isSameAs(innerGeneratedClass3);
+ }
+
@Test
void generateJavaFileIncludesGeneratedMethods() {
- ClassName name = ClassName.bestGuess("com.example.Test");
- GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer);
+ GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME);
generatedClass.getMethods().add("test", method -> method.addJavadoc("Test Method"));
assertThat(generatedClass.generateJavaFile().toString()).contains("Test Method");
}
+ @Test
+ void generateJavaFileIncludesDeclaredClasses() {
+ GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME);
+ generatedClass.getOrAdd("First", type -> type.modifiers.add(Modifier.STATIC));
+ generatedClass.getOrAdd("Second", type -> type.modifiers.add(Modifier.PRIVATE));
+ assertThat(generatedClass.generateJavaFile().toString())
+ .contains("static class First").contains("private class Second");
+ }
+
+ @Test
+ void generateJavaFileOnInnerClassThrowsException() {
+ GeneratedClass generatedClass = createGeneratedClass(TEST_CLASS_NAME)
+ .getOrAdd("Inner", emptyTypeCustomizer);
+ assertThatIllegalStateException().isThrownBy(generatedClass::generateJavaFile);
+ }
+
+ private static GeneratedClass createGeneratedClass(ClassName className) {
+ return new GeneratedClass(className, emptyTypeCustomizer);
+ }
+
}