From 8a4a89b9d92cbfd93aa766beae896df1ed62db4f Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Wed, 24 Aug 2022 07:21:02 +0200 Subject: [PATCH 1/3] Allow a MethodReference to be produced from a GeneratedMethod This commit updates GeneratedMethod and its underlying infrastructure to be able to produce a MethodReference. This simplifies the need when such a reference needs to be created manually and reuses more of what MethodReference has to offer. See gh-29005 --- ...opedProxyBeanRegistrationAotProcessor.java | 4 +- .../AutowiredAnnotationBeanPostProcessor.java | 4 +- .../aot/BeanDefinitionMethodGenerator.java | 6 +-- .../aot/BeanRegistrationsAotContribution.java | 3 +- .../aot/InstanceSupplierCodeGenerator.java | 4 +- .../BeanDefinitionMethodGeneratorTests.java | 6 +-- .../ConfigurationClassPostProcessor.java | 3 +- .../aot/generate/GeneratedClass.java | 2 +- .../aot/generate/GeneratedMethod.java | 22 ++++++++++- .../aot/generate/GeneratedMethods.java | 17 ++++++-- .../aot/generate/GeneratedMethodTests.java | 39 ++++++++++++++++--- .../aot/generate/GeneratedMethodsTests.java | 28 +++++++++---- ...agedTypesBeanRegistrationAotProcessor.java | 2 +- ...ersistenceAnnotationBeanPostProcessor.java | 4 +- 14 files changed, 101 insertions(+), 43 deletions(-) diff --git a/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java b/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java index 36d44764b15..5ad52df235b 100644 --- a/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java +++ b/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java @@ -163,8 +163,8 @@ class ScopedProxyBeanRegistrationAotProcessor implements BeanRegistrationAotProc method.addStatement("return ($T) factory.getObject()", beanClass); }); - return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, - beanRegistrationCode.getClassName(), generatedMethod.getName()); + return CodeBlock.of("$T.of($L)", InstanceSupplier.class, + generatedMethod.toMethodReference().toCodeBlock()); } } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java index 2154610858f..ace0f093ac6 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java @@ -44,7 +44,6 @@ import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.ExecutableMode; import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.BeanUtils; @@ -944,8 +943,7 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA method.returns(this.target); method.addCode(generateMethodCode(generationContext.getRuntimeHints())); }); - beanRegistrationCode.addInstancePostProcessor( - MethodReference.ofStatic(generatedClass.getName(), generateMethod.getName())); + beanRegistrationCode.addInstancePostProcessor(generateMethod.toMethodReference()); if (this.candidateResolver != null) { registerHints(generationContext.getRuntimeHints()); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java index 8f08985b792..29703c3eeb5 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -107,16 +107,14 @@ class BeanDefinitionMethodGenerator { GeneratedMethod generatedMethod = generateBeanDefinitionMethod( generationContext, generatedClass.getName(), generatedMethods, codeFragments, Modifier.PUBLIC); - return MethodReference.ofStatic(generatedClass.getName(), - generatedMethod.getName()); + return generatedMethod.toMethodReference(); } GeneratedMethods generatedMethods = beanRegistrationsCode.getMethods() .withPrefix(getName()); GeneratedMethod generatedMethod = generateBeanDefinitionMethod(generationContext, beanRegistrationsCode.getClassName(), generatedMethods, codeFragments, Modifier.PRIVATE); - return MethodReference.ofStatic(beanRegistrationsCode.getClassName(), - generatedMethod.getName()); + return generatedMethod.toMethodReference(); } private BeanRegistrationCodeFragments getCodeFragments(GenerationContext generationContext, diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java index 2dad04b877f..fc8ca237b8b 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -65,8 +65,7 @@ class BeanRegistrationsAotContribution BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator(generatedClass); GeneratedMethod generatedMethod = codeGenerator.getMethods().add("registerBeanDefinitions", method -> generateRegisterMethod(method, generationContext, codeGenerator)); - beanFactoryInitializationCode.addInitializer( - MethodReference.of(generatedClass.getName(), generatedMethod.getName())); + beanFactoryInitializationCode.addInitializer(generatedMethod.toMethodReference()); } private void generateRegisterMethod(MethodSpec.Builder method, diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index 66bba162047..1b597a36d37 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -296,8 +296,8 @@ class InstanceSupplierCodeGenerator { REGISTERED_BEAN_PARAMETER_NAME, declaringClass, factoryMethodName, args); } - private CodeBlock generateReturnStatement(GeneratedMethod getInstanceMethod) { - return CodeBlock.of("$T.$L()", this.className, getInstanceMethod.getName()); + private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) { + return generatedMethod.toMethodReference().toInvokeCodeBlock(); } private CodeBlock generateWithGeneratorCode(boolean hasArguments, CodeBlock newInstance) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 53c182ddc58..3cc278470a1 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -129,8 +129,7 @@ class BeanDefinitionMethodGeneratorTests { .addParameter(RegisteredBean.class, "registeredBean") .addParameter(TestBean.class, "testBean") .returns(TestBean.class).addCode("return new $T($S);", TestBean.class, "postprocessed")); - beanRegistrationCode.addInstancePostProcessor(MethodReference.ofStatic( - beanRegistrationCode.getClassName(), generatedMethod.getName())); + beanRegistrationCode.addInstancePostProcessor(generatedMethod.toMethodReference()); }; List aotContributions = Collections .singletonList(aotContribution); @@ -167,8 +166,7 @@ class BeanDefinitionMethodGeneratorTests { .addParameter(RegisteredBean.class, "registeredBean") .addParameter(TestBean.class, "testBean") .returns(TestBean.class).addCode("return new $T($S);", TestBean.class, "postprocessed")); - beanRegistrationCode.addInstancePostProcessor(MethodReference.ofStatic( - beanRegistrationCode.getClassName(), generatedMethod.getName())); + beanRegistrationCode.addInstancePostProcessor(generatedMethod.toMethodReference()); }; List aotContributions = Collections .singletonList(aotContribution); diff --git a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java index 3c2fa738179..02db1a37c9c 100644 --- a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java +++ b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java @@ -33,7 +33,6 @@ import org.apache.commons.logging.LogFactory; import org.springframework.aop.framework.autoproxy.AutoProxyUtils; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.ResourceHints; import org.springframework.aot.hint.TypeReference; import org.springframework.beans.PropertyValues; @@ -536,7 +535,7 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo .add("addImportAwareBeanPostProcessors", method -> generateAddPostProcessorMethod(method, mappings)); beanFactoryInitializationCode - .addInitializer(MethodReference.of(generatedMethod.getName())); + .addInitializer(generatedMethod.toMethodReference()); ResourceHints hints = generationContext.getRuntimeHints().resources(); mappings.forEach( (target, from) -> hints.registerType(TypeReference.of(from))); diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java index be591208fff..5a1bb302187 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java @@ -55,7 +55,7 @@ public final class GeneratedClass { GeneratedClass(ClassName name, Consumer type) { this.name = name; this.type = type; - this.methods = new GeneratedMethods(this::generateSequencedMethodName); + this.methods = new GeneratedMethods(name, this::generateSequencedMethodName); } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java index 4d351241d0d..7247c212d0c 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java @@ -18,6 +18,9 @@ package org.springframework.aot.generate; import java.util.function.Consumer; +import javax.lang.model.element.Modifier; + +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.MethodSpec; import org.springframework.util.Assert; @@ -25,11 +28,14 @@ import org.springframework.util.Assert; * A generated method. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see GeneratedMethods */ public final class GeneratedMethod { + private final ClassName className; + private final String name; private final MethodSpec methodSpec; @@ -39,12 +45,14 @@ public final class GeneratedMethod { * Create a new {@link GeneratedMethod} instance with the given name. This * constructor is package-private since names should only be generated via * {@link GeneratedMethods}. + * @param className the declaring class of the method * @param name the generated method name * @param method consumer to generate the method */ - GeneratedMethod(String name, Consumer method) { + GeneratedMethod(ClassName className, String name, Consumer method) { + this.className = className; this.name = name; - MethodSpec.Builder builder = MethodSpec.methodBuilder(getName()); + MethodSpec.Builder builder = MethodSpec.methodBuilder(this.name); method.accept(builder); this.methodSpec = builder.build(); Assert.state(this.name.equals(this.methodSpec.name), @@ -60,6 +68,16 @@ public final class GeneratedMethod { return this.name; } + /** + * Return a {@link MethodReference} to this generated method. + * @return a method reference + */ + public MethodReference toMethodReference() { + return (this.methodSpec.modifiers.contains(Modifier.STATIC) + ? MethodReference.ofStatic(this.className, this.name) + : MethodReference.of(this.className, this.name)); + } + /** * Return the {@link MethodSpec} for this generated method. * @return the method spec diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java index e16779ab779..0c65c37582a 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java @@ -22,6 +22,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Stream; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.MethodSpec.Builder; import org.springframework.util.Assert; @@ -30,11 +31,14 @@ import org.springframework.util.Assert; * A managed collection of generated methods. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see GeneratedMethod */ public class GeneratedMethods { + private final ClassName className; + private final Function methodNameGenerator; private final MethodName prefix; @@ -44,18 +48,22 @@ public class GeneratedMethods { /** * Create a new {@link GeneratedMethods} using the specified method name * generator. + * @param className the declaring class name * @param methodNameGenerator the method name generator */ - GeneratedMethods(Function methodNameGenerator) { + GeneratedMethods(ClassName className, Function methodNameGenerator) { + Assert.notNull(className, "'className' must not be null"); Assert.notNull(methodNameGenerator, "'methodNameGenerator' must not be null"); + this.className = className; this.methodNameGenerator = methodNameGenerator; this.prefix = MethodName.NONE; this.generatedMethods = new ArrayList<>(); } - private GeneratedMethods(Function methodNameGenerator, + private GeneratedMethods(ClassName className, Function methodNameGenerator, MethodName prefix, List generatedMethods) { + this.className = className; this.methodNameGenerator = methodNameGenerator; this.prefix = prefix; this.generatedMethods = generatedMethods; @@ -82,7 +90,7 @@ public class GeneratedMethods { Assert.notNull(suggestedNameParts, "'suggestedNameParts' must not be null"); Assert.notNull(method, "'method' must not be null"); String generatedName = this.methodNameGenerator.apply(this.prefix.and(suggestedNameParts)); - GeneratedMethod generatedMethod = new GeneratedMethod(generatedName, method); + GeneratedMethod generatedMethod = new GeneratedMethod(this.className, generatedName, method); this.generatedMethods.add(generatedMethod); return generatedMethod; } @@ -90,7 +98,8 @@ public class GeneratedMethods { public GeneratedMethods withPrefix(String prefix) { Assert.notNull(prefix, "'prefix' must not be null"); - return new GeneratedMethods(this.methodNameGenerator, this.prefix.and(prefix), this.generatedMethods); + return new GeneratedMethods(this.className, this.methodNameGenerator, + this.prefix.and(prefix), this.generatedMethods); } /** diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java index f080ce8ed20..34ac962746a 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java @@ -18,8 +18,12 @@ package org.springframework.aot.generate; import java.util.function.Consumer; +import javax.lang.model.element.Modifier; + import org.junit.jupiter.api.Test; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; import static org.assertj.core.api.Assertions.assertThat; @@ -29,30 +33,55 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException; * Tests for {@link GeneratedMethod}. * * @author Phillip Webb + * @author Stephane Nicoll */ class GeneratedMethodTests { - private static final Consumer methodSpecCustomizer = method -> {}; + private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test"); + + private static final Consumer emptyMethod = method -> {}; private static final String NAME = "spring"; @Test void getNameReturnsName() { - GeneratedMethod generatedMethod = new GeneratedMethod(NAME, methodSpecCustomizer); + GeneratedMethod generatedMethod = new GeneratedMethod(TEST_CLASS_NAME, NAME, emptyMethod); assertThat(generatedMethod.getName()).isSameAs(NAME); } @Test void generateMethodSpecReturnsMethodSpec() { - GeneratedMethod generatedMethod = new GeneratedMethod(NAME, method -> method.addJavadoc("Test")); + GeneratedMethod generatedMethod = create(method -> method.addJavadoc("Test")); assertThat(generatedMethod.getMethodSpec().javadoc).asString().contains("Test"); } @Test void generateMethodSpecWhenMethodNameIsChangedThrowsException() { assertThatIllegalStateException().isThrownBy(() -> - new GeneratedMethod(NAME, method -> method.setName("badname")).getMethodSpec()) - .withMessage("'method' consumer must not change the generated method name"); + create(method -> method.setName("badname")).getMethodSpec()) + .withMessage("'method' consumer must not change the generated method name"); + } + + @Test + void toMethodReferenceWithInstanceMethod() { + GeneratedMethod generatedMethod = create(emptyMethod); + MethodReference methodReference = generatedMethod.toMethodReference(); + assertThat(methodReference).isNotNull(); + assertThat(methodReference.toInvokeCodeBlock("test")) + .isEqualTo(CodeBlock.of("test.spring()")); + } + + @Test + void toMethodReferenceWithStaticMethod() { + GeneratedMethod generatedMethod = create(method -> method.addModifiers(Modifier.STATIC)); + MethodReference methodReference = generatedMethod.toMethodReference(); + assertThat(methodReference).isNotNull(); + assertThat(methodReference.toInvokeCodeBlock()) + .isEqualTo(CodeBlock.of("com.example.Test.spring()")); + } + + private GeneratedMethod create(Consumer method) { + return new GeneratedMethod(TEST_CLASS_NAME, NAME, method); } } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java index 2ae2517e6c8..1b051f691e0 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java @@ -23,6 +23,7 @@ import java.util.function.Function; import org.junit.jupiter.api.Test; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.MethodSpec; import static org.assertj.core.api.Assertions.assertThat; @@ -32,38 +33,49 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException * Tests for {@link GeneratedMethods}. * * @author Phillip Webb + * @author Stephane Nicoll */ class GeneratedMethodsTests { + private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test"); + private static final Consumer methodSpecCustomizer = method -> {}; - private final GeneratedMethods methods = new GeneratedMethods(MethodName::toString); + private final GeneratedMethods methods = new GeneratedMethods(TEST_CLASS_NAME, MethodName::toString); + + @Test + void createWhenClassNameIsNullThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> + new GeneratedMethods(null, MethodName::toString)) + .withMessage("'className' must not be null"); + } @Test void createWhenMethodNameGeneratorIsNullThrowsException() { - assertThatIllegalArgumentException().isThrownBy(() -> new GeneratedMethods(null)) + assertThatIllegalArgumentException().isThrownBy(() -> + new GeneratedMethods(TEST_CLASS_NAME, null)) .withMessage("'methodNameGenerator' must not be null"); } @Test void createWithExistingGeneratorUsesGenerator() { Function generator = name -> "__" + name.toString(); - GeneratedMethods methods = new GeneratedMethods(generator); + GeneratedMethods methods = new GeneratedMethods(TEST_CLASS_NAME, generator); assertThat(methods.add("test", methodSpecCustomizer).getName()).hasToString("__test"); } @Test void addWithStringNameWhenSuggestedMethodIsNullThrowsException() { assertThatIllegalArgumentException().isThrownBy(() -> - this.methods.add((String) null, methodSpecCustomizer)) - .withMessage("'suggestedName' must not be null"); + this.methods.add((String) null, methodSpecCustomizer)) + .withMessage("'suggestedName' must not be null"); } @Test void addWithStringNameWhenMethodIsNullThrowsException() { assertThatIllegalArgumentException().isThrownBy(() -> - this.methods.add("test", null)) - .withMessage("'method' must not be null"); + this.methods.add("test", null)) + .withMessage("'method' must not be null"); } @Test @@ -71,7 +83,7 @@ class GeneratedMethodsTests { this.methods.add("springBeans", methodSpecCustomizer); this.methods.add("springContext", methodSpecCustomizer); assertThat(this.methods.stream().map(GeneratedMethod::getName).map(Object::toString)) - .containsExactly("springBeans", "springContext"); + .containsExactly("springBeans", "springContext"); } @Test diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java index cf88dc8acbc..e8d8fe3848b 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java @@ -99,7 +99,7 @@ class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistr List.class, toCodeBlock(persistenceManagedTypes.getManagedPackages())); method.addStatement("return $T.of($L, $L)", beanType, "managedClassNames", "managedPackages"); }); - return CodeBlock.of("() -> $T.$L()", beanRegistrationCode.getClassName(), generatedMethod.getName()); + return generatedMethod.toMethodReference().toCodeBlock(); } private CodeBlock toCodeBlock(List values) { diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java index 47078596607..abac448ac02 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java @@ -43,7 +43,6 @@ import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.BeanUtils; import org.springframework.beans.PropertyValues; @@ -797,8 +796,7 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar method.returns(this.target); method.addCode(generateMethodCode(generationContext.getRuntimeHints(), generatedClass.getMethods())); }); - beanRegistrationCode.addInstancePostProcessor(MethodReference - .ofStatic(generatedClass.getName(), generatedMethod.getName())); + beanRegistrationCode.addInstancePostProcessor(generatedMethod.toMethodReference()); } private CodeBlock generateMethodCode(RuntimeHints hints, GeneratedMethods generatedMethods) { From ae706f3954b1823626bee1373f306b20fb9239d3 Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Wed, 24 Aug 2022 07:53:38 +0200 Subject: [PATCH 2/3] Allow MethodReference to define a more flexible signature This commit moves MethodReference to an interface with a default implementation that relies on a MethodSpec. Such an arrangement avoid the need of specifying attributes of the method such as whether it is static or not. The resolution of the invocation block now takes an ArgumentCodeGenerator rather than the raw arguments. Doing so gives the opportunity to create more flexible signatures. See gh-29005 --- ...roxyBeanRegistrationAotProcessorTests.java | 6 +- .../aot/BeanRegistrationsAotContribution.java | 5 +- .../DefaultBeanRegistrationCodeFragments.java | 3 +- .../aot/InstanceSupplierCodeGenerator.java | 4 +- ...nBeanRegistrationAotContributionTests.java | 6 +- .../BeanDefinitionMethodGeneratorTests.java | 5 +- ...BeanRegistrationsAotContributionTests.java | 6 +- .../MockBeanFactoryInitializationCode.java | 4 + ...ionContextInitializationCodeGenerator.java | 8 +- ...lassPostProcessorAotContributionTests.java | 6 +- .../aot/generate/DefaultMethodReference.java | 134 +++++++++ .../aot/generate/GeneratedMethod.java | 6 +- .../aot/generate/MethodReference.java | 271 ++++++------------ .../generate/DefaultMethodReferenceTests.java | 199 +++++++++++++ .../aot/generate/GeneratedMethodTests.java | 8 +- .../aot/generate/MethodReferenceTests.java | 226 --------------- 16 files changed, 469 insertions(+), 428 deletions(-) create mode 100644 spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java create mode 100644 spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java delete mode 100644 spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java diff --git a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java index e29a3192492..1ce8b798680 100644 --- a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java +++ b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java @@ -26,6 +26,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aop.framework.AopInfrastructureBean; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.Compiled; import org.springframework.aot.test.generate.compile.TestCompiler; @@ -139,11 +140,14 @@ class ScopedProxyBeanRegistrationAotProcessorTests { MethodReference methodReference = this.beanFactoryInitializationCode .getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java index fc8ca237b8b..a80db112d35 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; @@ -81,9 +82,11 @@ class BeanRegistrationsAotContribution MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator .generateBeanDefinitionMethod(generationContext, beanRegistrationsCode); + CodeBlock methodInvocation = beanDefinitionMethod.toInvokeCodeBlock( + ArgumentCodeGenerator.none(), beanRegistrationsCode.getClassName()); code.addStatement("$L.registerBeanDefinition($S, $L)", BEAN_FACTORY_PARAMETER_NAME, beanName, - beanDefinitionMethod.toInvokeCodeBlock()); + methodInvocation); }); method.addCode(code.build()); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index bf5719eccdb..fa04f862115 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -24,6 +24,7 @@ import java.util.function.Predicate; import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; @@ -156,7 +157,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments MethodReference generatedMethod = methodGenerator .generateBeanDefinitionMethod(generationContext, this.beanRegistrationsCode); - return generatedMethod.toInvokeCodeBlock(); + return generatedMethod.toInvokeCodeBlock(ArgumentCodeGenerator.none()); } return null; } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index 1b597a36d37..e6cb5df84f1 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -28,6 +28,7 @@ import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.ExecutableMode; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; @@ -297,7 +298,8 @@ class InstanceSupplierCodeGenerator { } private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) { - return generatedMethod.toMethodReference().toInvokeCodeBlock(); + return generatedMethod.toMethodReference().toInvokeCodeBlock( + ArgumentCodeGenerator.none(), this.className); } private CodeBlock generateWithGeneratorCode(boolean hasArguments, CodeBlock newInstance) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java index 2f4fe187b13..219093424e7 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java @@ -24,6 +24,7 @@ import javax.lang.model.element.Modifier; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess; @@ -161,13 +162,16 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { Class target = registeredBean.getBeanClass(); MethodReference methodReference = this.beanRegistrationCode.getInstancePostProcessors().get(0); this.beanRegistrationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(RegisteredBean.class, "registeredBean").and(target, "instance"), + this.beanRegistrationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(BiFunction.class, RegisteredBean.class, target, target)); type.addMethod(MethodSpec.methodBuilder("apply") .addModifiers(Modifier.PUBLIC) .addParameter(RegisteredBean.class, "registeredBean") .addParameter(target, "instance").returns(target) - .addStatement("return $L", methodReference.toInvokeCodeBlock(CodeBlock.of("registeredBean"), CodeBlock.of("instance"))) + .addStatement("return $L", methodInvocation) .build()); }); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 3cc278470a1..9020bf8626a 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess; import org.springframework.aot.test.generate.compile.Compiled; @@ -414,12 +415,14 @@ class BeanDefinitionMethodGeneratorTests { private void compile(MethodReference method, BiConsumer result) { this.beanRegistrationsCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = method.toInvokeCodeBlock(ArgumentCodeGenerator.none(), + this.beanRegistrationsCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, BeanDefinition.class)); type.addMethod(MethodSpec.methodBuilder("get") .addModifiers(Modifier.PUBLIC) .returns(BeanDefinition.class) - .addCode("return $L;", method.toInvokeCodeBlock()).build()); + .addCode("return $L;", methodInvocation).build()); }); this.generationContext.writeGeneratedContent(); TestCompiler.forSystem().withFiles(this.generationContext.getGeneratedFiles()).compile(compiled -> diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java index 1eee3a83434..bd2bba145e9 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.generate.ClassNameGenerator; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.TestTarget; import org.springframework.aot.test.generate.compile.Compiled; @@ -155,11 +156,14 @@ class BeanRegistrationsAotContributionTests { MethodReference methodReference = this.beanFactoryInitializationCode .getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java index 01a78dda3b4..c6986c7c4b0 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; +import org.springframework.javapoet.ClassName; /** * Mock {@link BeanFactoryInitializationCode} implementation. @@ -46,6 +47,9 @@ public class MockBeanFactoryInitializationCode implements BeanFactoryInitializat .addForFeature("TestCode", this.typeBuilder); } + public ClassName getClassName() { + return this.generatedClass.getName(); + } public DeferredTypeBuilder getTypeBuilder() { return this.typeBuilder; diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java index 29f502c7353..b2bf870e2f5 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationContextInitializer; @@ -88,12 +89,17 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class); code.addStatement("$L.setDependencyComparator($T.INSTANCE)", BEAN_FACTORY_VARIABLE, AnnotationAwareOrderComparator.class); + ArgumentCodeGenerator argCodeGenerator = createInitializerMethodsArgumentCodeGenerator(); for (MethodReference initializer : this.initializers) { - code.addStatement(initializer.toInvokeCodeBlock(CodeBlock.of(BEAN_FACTORY_VARIABLE))); + code.addStatement(initializer.toInvokeCodeBlock(argCodeGenerator, this.generatedClass.getName())); } return code.build(); } + private ArgumentCodeGenerator createInitializerMethodsArgumentCodeGenerator() { + return ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, BEAN_FACTORY_VARIABLE); + } + GeneratedClass getGeneratedClass() { return this.generatedClass; } diff --git a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java index 24bef4dcf15..29961c61070 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java @@ -27,6 +27,7 @@ import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.ResourcePatternHint; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.Compiled; @@ -162,11 +163,14 @@ class ConfigurationClassPostProcessorAotContributionTests { private void compile(BiConsumer, Compiled> result) { MethodReference methodReference = this.beanFactoryInitializationCode.getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java b/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java new file mode 100644 index 00000000000..b3a3ab117d8 --- /dev/null +++ b/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java @@ -0,0 +1,134 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import java.util.ArrayList; +import java.util.List; + +import javax.lang.model.element.Modifier; + +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.TypeName; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Default {@link MethodReference} implementation based on a {@link MethodSpec}. + * + * @author Stephane Nicoll + * @author Phillip Webb + * @since 6.0 + */ +public class DefaultMethodReference implements MethodReference { + + private final MethodSpec method; + + @Nullable + private final ClassName declaringClass; + + public DefaultMethodReference(MethodSpec method, @Nullable ClassName declaringClass) { + this.method = method; + this.declaringClass = declaringClass; + } + + @Override + public CodeBlock toCodeBlock() { + String methodName = this.method.name; + if (isStatic()) { + Assert.notNull(this.declaringClass, "static method reference must define a declaring class"); + return CodeBlock.of("$T::$L", this.declaringClass, methodName); + } + else { + return CodeBlock.of("this::$L", methodName); + } + } + + public CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator, + @Nullable ClassName targetClassName) { + String methodName = this.method.name; + CodeBlock.Builder code = CodeBlock.builder(); + if (isStatic()) { + Assert.notNull(this.declaringClass, "static method reference must define a declaring class"); + if (isSameDeclaringClass(targetClassName)) { + code.add("$L", methodName); + } + else { + code.add("$T.$L", this.declaringClass, methodName); + } + } + else { + if (!isSameDeclaringClass(targetClassName)) { + code.add(instantiateDeclaringClass(this.declaringClass)); + } + code.add("$L", methodName); + } + code.add("("); + addArguments(code, argumentCodeGenerator); + code.add(")"); + return code.build(); + } + + /** + * Add the code for the method arguments using the specified + * {@link ArgumentCodeGenerator} if necessary. + * @param code the code builder to use to add method arguments + * @param argumentCodeGenerator the code generator to use + */ + protected void addArguments(CodeBlock.Builder code, ArgumentCodeGenerator argumentCodeGenerator) { + List arguments = new ArrayList<>(); + TypeName[] argumentTypes = this.method.parameters.stream() + .map(parameter -> parameter.type).toArray(TypeName[]::new); + for (int i = 0; i < argumentTypes.length; i++) { + TypeName argumentType = argumentTypes[i]; + CodeBlock argumentCode = argumentCodeGenerator.generateCode(argumentType); + if (argumentCode == null) { + throw new IllegalArgumentException("Could not generate code for " + this + + ": parameter " + i + " of type " + argumentType + " is not supported"); + } + arguments.add(argumentCode); + } + code.add(CodeBlock.join(arguments, ", ")); + } + + protected CodeBlock instantiateDeclaringClass(ClassName declaringClass) { + return CodeBlock.of("new $T().", declaringClass); + } + + private boolean isStatic() { + return this.method.modifiers.contains(Modifier.STATIC); + } + + private boolean isSameDeclaringClass(ClassName declaringClass) { + return this.declaringClass == null || this.declaringClass.equals(declaringClass); + } + + @Override + public String toString() { + String methodName = this.method.name; + if (isStatic()) { + return this.declaringClass + "::" + methodName; + } + else { + return ((this.declaringClass != null) + ? "<" + this.declaringClass + ">" : "") + + "::" + methodName; + } + } + +} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java index 7247c212d0c..b09d36f61f2 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java @@ -18,8 +18,6 @@ package org.springframework.aot.generate; import java.util.function.Consumer; -import javax.lang.model.element.Modifier; - import org.springframework.javapoet.ClassName; import org.springframework.javapoet.MethodSpec; import org.springframework.util.Assert; @@ -73,9 +71,7 @@ public final class GeneratedMethod { * @return a method reference */ public MethodReference toMethodReference() { - return (this.methodSpec.modifiers.contains(Modifier.STATIC) - ? MethodReference.ofStatic(this.className, this.name) - : MethodReference.of(this.className, this.name)); + return new DefaultMethodReference(this.methodSpec, this.className); } /** diff --git a/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java b/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java index 80359dd314b..f6dda971007 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java @@ -16,223 +16,124 @@ package org.springframework.aot.generate; +import java.util.function.Function; + import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.TypeName; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; /** - * A reference to a static or instance method. + * A reference to a method with convenient code generation for + * referencing, or invoking it. * + * @author Stephane Nicoll * @author Phillip Webb * @since 6.0 */ -public final class MethodReference { - - private final Kind kind; - - @Nullable - private final ClassName declaringClass; - - private final String methodName; - - - private MethodReference(Kind kind, @Nullable ClassName declaringClass, - String methodName) { - this.kind = kind; - this.declaringClass = declaringClass; - this.methodName = methodName; - } - - - /** - * Create a new method reference that refers to the given instance method. - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(String methodName) { - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, null, methodName); - } - - /** - * Create a new method reference that refers to the given instance method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(Class declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, ClassName.get(declaringClass), - methodName); - } - - /** - * Create a new method reference that refers to the given instance method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(ClassName declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, declaringClass, methodName); - } - - /** - * Create a new method reference that refers to the given static method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference ofStatic(Class declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.STATIC, ClassName.get(declaringClass), - methodName); - } - - /** - * Create a new method reference that refers to the given static method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference ofStatic(ClassName declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.STATIC, declaringClass, methodName); - } - - - /** - * Return the referenced declaring class. - * @return the declaring class - */ - @Nullable - public ClassName getDeclaringClass() { - return this.declaringClass; - } - - /** - * Return the referenced method name. - * @return the method name - */ - public String getMethodName() { - return this.methodName; - } +public interface MethodReference { /** * Return this method reference as a {@link CodeBlock}. If the reference is * to an instance method then {@code this::} will be returned. * @return a code block for the method reference. - * @see #toCodeBlock(String) */ - public CodeBlock toCodeBlock() { - return toCodeBlock(null); - } + CodeBlock toCodeBlock(); /** - * Return this method reference as a {@link CodeBlock}. If the reference is - * to an instance method and {@code instanceVariable} is {@code null} then - * {@code this::} will be returned. No {@code instanceVariable} - * can be specified for static method references. - * @param instanceVariable the instance variable or {@code null} - * @return a code block for the method reference. - * @see #toCodeBlock(String) + * Return this method reference as a {@link CodeBlock} using the specified + * {@link ArgumentCodeGenerator}. + * @param argumentCodeGenerator the argument code generator to use + * @return a code block to invoke the method */ - public CodeBlock toCodeBlock(@Nullable String instanceVariable) { - return switch (this.kind) { - case INSTANCE -> toCodeBlockForInstance(instanceVariable); - case STATIC -> toCodeBlockForStatic(instanceVariable); - }; - } - - private CodeBlock toCodeBlockForInstance(@Nullable String instanceVariable) { - instanceVariable = (instanceVariable != null) ? instanceVariable : "this"; - return CodeBlock.of("$L::$L", instanceVariable, this.methodName); - } - - private CodeBlock toCodeBlockForStatic(@Nullable String instanceVariable) { - Assert.isTrue(instanceVariable == null, - "'instanceVariable' must be null for static method references"); - return CodeBlock.of("$T::$L", this.declaringClass, this.methodName); + default CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator) { + return toInvokeCodeBlock(argumentCodeGenerator, null); } /** - * Return this method reference as an invocation {@link CodeBlock}. - * @param arguments the method arguments - * @return a code back to invoke the method + * Return this method reference as a {@link CodeBlock} using the specified + * {@link ArgumentCodeGenerator}. The {@code targetClassName} defines the + * context in which the method invocation is added. + *

If the caller has an instance of the type in which this method is + * defined, it can hint that by specifying the type as a target class. + * @param argumentCodeGenerator the argument code generator to use + * @param targetClassName the target class name + * @return a code block to invoke the method */ - public CodeBlock toInvokeCodeBlock(CodeBlock... arguments) { - return toInvokeCodeBlock(null, arguments); - } + CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator, @Nullable ClassName targetClassName); + /** - * Return this method reference as an invocation {@link CodeBlock}. - * @param instanceVariable the instance variable or {@code null} - * @param arguments the method arguments - * @return a code back to invoke the method + * Strategy for generating code for arguments based on their type. */ - public CodeBlock toInvokeCodeBlock(@Nullable String instanceVariable, - CodeBlock... arguments) { - - return switch (this.kind) { - case INSTANCE -> toInvokeCodeBlockForInstance(instanceVariable, arguments); - case STATIC -> toInvokeCodeBlockForStatic(instanceVariable, arguments); - }; - } - - private CodeBlock toInvokeCodeBlockForInstance(@Nullable String instanceVariable, - CodeBlock[] arguments) { - - CodeBlock.Builder code = CodeBlock.builder(); - if (instanceVariable != null) { - code.add("$L.", instanceVariable); - } - else if (this.declaringClass != null) { - code.add("new $T().", this.declaringClass); + interface ArgumentCodeGenerator { + + /** + * Generate the code for the given argument type. If this type is + * not supported, return {@code null}. + * @param argumentType the argument type + * @return the code for this argument, or {@code null} + */ + @Nullable + CodeBlock generateCode(TypeName argumentType); + + /** + * Factory method that returns an {@link ArgumentCodeGenerator} that + * always returns {@code null}. + * @return a new {@link ArgumentCodeGenerator} instance + */ + static ArgumentCodeGenerator none() { + return from(type -> null); } - code.add("$L", this.methodName); - addArguments(code, arguments); - return code.build(); - } - private CodeBlock toInvokeCodeBlockForStatic(@Nullable String instanceVariable, - CodeBlock[] arguments) { - - Assert.isTrue(instanceVariable == null, - "'instanceVariable' must be null for static method references"); - CodeBlock.Builder code = CodeBlock.builder(); - code.add("$T.$L", this.declaringClass, this.methodName); - addArguments(code, arguments); - return code.build(); - } + /** + * Factory method that can be used to create an {@link ArgumentCodeGenerator} + * that support only the given argument type. + * @param argumentType the argument type + * @param argumentCode the code for an argument of that type + * @return a new {@link ArgumentCodeGenerator} instance + */ + static ArgumentCodeGenerator of(Class argumentType, String argumentCode) { + return from(candidateType -> (candidateType.equals(ClassName.get(argumentType)) + ? CodeBlock.of(argumentCode) : null)); + } - private void addArguments(CodeBlock.Builder code, CodeBlock[] arguments) { - code.add("("); - for (int i = 0; i < arguments.length; i++) { - if (i != 0) { - code.add(", "); - } - code.add(arguments[i]); + /** + * Factory method that creates a new {@link ArgumentCodeGenerator} from + * a lambda friendly function. The given function is provided with the + * argument type and must provide the code to use or {@code null} if + * the type is not supported. + * @param function the resolver function + * @return a new {@link ArgumentCodeGenerator} instance backed by the function + */ + static ArgumentCodeGenerator from(Function function) { + return function::apply; } - code.add(")"); - } - @Override - public String toString() { - return switch (this.kind) { - case INSTANCE -> ((this.declaringClass != null) ? "<" + this.declaringClass + ">" - : "") + "::" + this.methodName; - case STATIC -> this.declaringClass + "::" + this.methodName; - }; - } + /** + * Create a new composed {@link ArgumentCodeGenerator} by combining this + * generator with supporting the given argument type. + * @param argumentType the argument type + * @param argumentCode the code for an argument of that type + * @return a new composite {@link ArgumentCodeGenerator} instance + */ + default ArgumentCodeGenerator and(Class argumentType, String argumentCode) { + return and(ArgumentCodeGenerator.of(argumentType, argumentCode)); + } + /** + * Create a new composed {@link ArgumentCodeGenerator} by combining this + * generator with the given generator. + * @param argumentCodeGenerator the argument generator to add + * @return a new composite {@link ArgumentCodeGenerator} instance + */ + default ArgumentCodeGenerator and(ArgumentCodeGenerator argumentCodeGenerator) { + return from(type -> { + CodeBlock code = generateCode(type); + return (code != null ? code : argumentCodeGenerator.generateCode(type)); + }); + } - private enum Kind { - INSTANCE, STATIC } } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java b/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java new file mode 100644 index 00000000000..b9643151ca9 --- /dev/null +++ b/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java @@ -0,0 +1,199 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.MethodSpec.Builder; +import org.springframework.javapoet.TypeName; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link DefaultMethodReference}. + * + * @author Phillip Webb + * @author Stephane Nicoll + */ +class DefaultMethodReferenceTests { + + private static final String EXPECTED_STATIC = "org.springframework.aot.generate.DefaultMethodReferenceTests::someMethod"; + + private static final String EXPECTED_ANONYMOUS_INSTANCE = "::someMethod"; + + private static final String EXPECTED_DECLARED_INSTANCE = "::someMethod"; + + private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test"); + + private static final ClassName INITIALIZER_CLASS_NAME = ClassName.get("com.example", "Initializer"); + + @Test + void createWithStringCreatesMethodReference() { + MethodSpec method = createTestMethod("someMethod", new TypeName[0]); + MethodReference reference = new DefaultMethodReference(method, null); + assertThat(reference).hasToString(EXPECTED_ANONYMOUS_INSTANCE); + } + + @Test + void createWithClassNameAndStringCreateMethodReference() { + ClassName declaringClass = ClassName.get(DefaultMethodReferenceTests.class); + MethodReference reference = createMethodReference("someMethod", new TypeName[0], declaringClass); + assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); + } + + @Test + void createWithStaticAndClassAndStringCreatesMethodReference() { + ClassName declaringClass = ClassName.get(DefaultMethodReferenceTests.class); + MethodReference reference = createStaticMethodReference("someMethod", declaringClass); + assertThat(reference).hasToString(EXPECTED_STATIC); + } + + @Test + void toCodeBlock() { + assertThat(createLocalMethodReference("methodName").toCodeBlock()) + .isEqualTo(CodeBlock.of("this::methodName")); + } + + @Test + void toCodeBlockWithStaticMethod() { + assertThat(createStaticMethodReference("methodName", TEST_CLASS_NAME).toCodeBlock()) + .isEqualTo(CodeBlock.of("com.example.Test::methodName")); + } + + @Test + void toCodeBlockWithStaticMethodRequiresDeclaringClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0], Modifier.STATIC); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThatIllegalArgumentException().isThrownBy(methodReference::toCodeBlock) + .withMessage("static method reference must define a declaring class"); + } + + @Test + void toInvokeCodeBlockWithNullDeclaringClassAndTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithNullDeclaringClassAndNullTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none())) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithDeclaringClassAndNullTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none())) + .isEqualTo(CodeBlock.of("new com.example.Test().methodName()")); + } + + @Test + void toInvokeCodeBlockWithMatchingTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + CodeBlock invocation = methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME); + // Assume com.example.Test is in a `test` variable. + assertThat(CodeBlock.of("$L.$L", "test", invocation)).isEqualTo(CodeBlock.of("test.methodName()")); + } + + @Test + void toInvokeCodeBlockWithNonMatchingDeclaringClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), INITIALIZER_CLASS_NAME)) + .isEqualTo(CodeBlock.of("new com.example.Test().methodName()")); + } + + @Test + void toInvokeCodeBlockWithMatchingArg() { + MethodReference methodReference = createLocalMethodReference("methodName", ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(String.class, "stringArg"); + assertThat(methodReference.toInvokeCodeBlock(argCodeGenerator)) + .isEqualTo(CodeBlock.of("methodName(stringArg)")); + } + + @Test + void toInvokeCodeBlockWithMatchingArgs() { + MethodReference methodReference = createLocalMethodReference("methodName", + ClassName.get(Integer.class), ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(String.class, "stringArg") + .and(Integer.class, "integerArg"); + assertThat(methodReference.toInvokeCodeBlock(argCodeGenerator)) + .isEqualTo(CodeBlock.of("methodName(integerArg, stringArg)")); + } + + @Test + void toInvokeCodeBlockWithNonMatchingArg() { + MethodReference methodReference = createLocalMethodReference("methodName", + ClassName.get(Integer.class), ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(Integer.class, "integerArg"); + assertThatIllegalArgumentException().isThrownBy(() -> methodReference.toInvokeCodeBlock(argCodeGenerator)) + .withMessageContaining("parameter 1 of type java.lang.String is not supported"); + } + + @Test + void toInvokeCodeBlockWithStaticMethodAndMatchingDeclaringClass() { + MethodReference methodReference = createStaticMethodReference("methodName", TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithStaticMethodAndSeparateDeclaringClass() { + MethodReference methodReference = createStaticMethodReference("methodName", TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), INITIALIZER_CLASS_NAME)) + .isEqualTo(CodeBlock.of("com.example.Test.methodName()")); + } + + + private MethodReference createLocalMethodReference(String name, TypeName... argumentTypes) { + return createMethodReference(name, argumentTypes, null); + } + + private MethodReference createMethodReference(String name, TypeName[] argumentTypes, @Nullable ClassName declaringClass) { + MethodSpec method = createTestMethod(name, argumentTypes); + return new DefaultMethodReference(method, declaringClass); + } + + private MethodReference createStaticMethodReference(String name, ClassName declaringClass, TypeName... argumentTypes) { + MethodSpec method = createTestMethod(name, argumentTypes, Modifier.STATIC); + return new DefaultMethodReference(method, declaringClass); + } + + private MethodSpec createTestMethod(String name, TypeName[] argumentTypes, Modifier... modifiers) { + Builder method = MethodSpec.methodBuilder(name); + for (int i = 0; i < argumentTypes.length; i++) { + method.addParameter(argumentTypes[i], "args" + i); + } + method.addModifiers(modifiers); + return method.build(); + } + +} diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java index 34ac962746a..6e865bd4eb8 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java @@ -22,6 +22,7 @@ import javax.lang.model.element.Modifier; import org.junit.jupiter.api.Test; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; @@ -67,8 +68,8 @@ class GeneratedMethodTests { GeneratedMethod generatedMethod = create(emptyMethod); MethodReference methodReference = generatedMethod.toMethodReference(); assertThat(methodReference).isNotNull(); - assertThat(methodReference.toInvokeCodeBlock("test")) - .isEqualTo(CodeBlock.of("test.spring()")); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("spring()")); } @Test @@ -76,7 +77,8 @@ class GeneratedMethodTests { GeneratedMethod generatedMethod = create(method -> method.addModifiers(Modifier.STATIC)); MethodReference methodReference = generatedMethod.toMethodReference(); assertThat(methodReference).isNotNull(); - assertThat(methodReference.toInvokeCodeBlock()) + ClassName anotherDeclaringClass = ClassName.get("com.example", "Another"); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), anotherDeclaringClass)) .isEqualTo(CodeBlock.of("com.example.Test.spring()")); } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java b/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java deleted file mode 100644 index de5c79667b4..00000000000 --- a/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import org.junit.jupiter.api.Test; - -import org.springframework.javapoet.ClassName; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; - -/** - * Tests for {@link MethodReference}. - * - * @author Phillip Webb - */ -class MethodReferenceTests { - - private static final String EXPECTED_STATIC = "org.springframework.aot.generate.MethodReferenceTests::someMethod"; - - private static final String EXPECTED_ANONYMOUS_INSTANCE = "::someMethod"; - - private static final String EXPECTED_DECLARED_INSTANCE = "::someMethod"; - - - @Test - void ofWithStringWhenMethodNameIsNullThrowsException() { - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithStringCreatesMethodReference() { - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(methodName); - assertThat(reference).hasToString(EXPECTED_ANONYMOUS_INSTANCE); - } - - @Test - void ofWithClassAndStringWhenDeclaringClassIsNullThrowsException() { - Class declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofWithClassAndStringWhenMethodNameIsNullThrowsException() { - Class declaringClass = MethodReferenceTests.class; - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithClassAndStringCreatesMethodReference() { - Class declaringClass = MethodReferenceTests.class; - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); - } - - @Test - void ofWithClassNameAndStringWhenDeclaringClassIsNullThrowsException() { - ClassName declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofWithClassNameAndStringWhenMethodNameIsNullThrowsException() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithClassNameAndStringCreateMethodReference() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); - } - - @Test - void ofStaticWithClassAndStringWhenDeclaringClassIsNullThrowsException() { - Class declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofStaticWithClassAndStringWhenMethodNameIsEmptyThrowsException() { - Class declaringClass = MethodReferenceTests.class; - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofStaticWithClassAndStringCreatesMethodReference() { - Class declaringClass = MethodReferenceTests.class; - String methodName = "someMethod"; - MethodReference reference = MethodReference.ofStatic(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_STATIC); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameWhenDeclaringClassIsNullThrowsException() { - ClassName declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameWhenMethodNameIsEmptyThrowsException() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameCreatesMethodReference() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = "someMethod"; - MethodReference reference = MethodReference.ofStatic(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_STATIC); - } - - @Test - void toCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toCodeBlock(null)).hasToString("this::someMethod"); - } - - @Test - void toCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNotNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toCodeBlock("myInstance")) - .hasToString("myInstance::someMethod"); - } - - @Test - void toCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toCodeBlock(null)).hasToString(EXPECTED_STATIC); - } - - @Test - void toCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNotNullThrowsException() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThatIllegalArgumentException() - .isThrownBy(() -> reference.toCodeBlock("myInstance")).withMessage( - "'instanceVariable' must be null for static method references"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString("someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNullAndHasDecalredClass() { - MethodReference reference = MethodReference.of(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString( - "new org.springframework.aot.generate.MethodReferenceTests().someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNotNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toInvokeCodeBlock("myInstance")) - .hasToString("myInstance.someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString( - "org.springframework.aot.generate.MethodReferenceTests.someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNotNullThrowsException() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThatIllegalArgumentException() - .isThrownBy(() -> reference.toInvokeCodeBlock("myInstance")).withMessage( - "'instanceVariable' must be null for static method references"); - } - -} From 2b45fd438827cc0b62484283676b1ac8bd11cf18 Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Wed, 7 Sep 2022 16:04:29 +0200 Subject: [PATCH 3/3] Allow bean factory initialization to have a more flexible signature This commit allows bean factory initialization to use a more flexible signature than just consuming the DefaultListableBeanFactory. The environment and the resource loader can now be specified if necessary. See gh-29005 --- .../aot/BeanFactoryInitializationCode.java | 15 +++- ...ionContextInitializationCodeGenerator.java | 40 +++++++++- ...ntextInitializationCodeGeneratorTests.java | 77 +++++++++++++++++++ 3 files changed, 125 insertions(+), 7 deletions(-) create mode 100644 spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java index 654fdda5866..e7da3299e81 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java @@ -24,6 +24,7 @@ import org.springframework.aot.generate.MethodReference; * perform bean factory initialization. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see BeanFactoryInitializationAotContribution */ @@ -41,10 +42,16 @@ public interface BeanFactoryInitializationCode { GeneratedMethods getMethods(); /** - * Add an initializer method call. - * @param methodReference a reference to the initialize method to call. The - * referenced method must have the same functional signature as - * {@code Consumer}. + * Add an initializer method call. An initializer can use a flexible signature, + * using any of the following: + *

    + *
  • {@code DefaultListableBeanFactory}, or {@code ConfigurableListableBeanFactory} + * to use the bean factory.
  • + *
  • {@code ConfigurableEnvironment} or {@code Environment} to access the + * environment.
  • + *
  • {@code ResourceLoader} to load resources.
  • + *
+ * @param methodReference a reference to the initialize method to call. */ void addInitializer(MethodReference methodReference); diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java index b2bf870e2f5..1be6885e9f8 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java @@ -18,6 +18,7 @@ package org.springframework.context.aot; import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import javax.lang.model.element.Modifier; @@ -27,15 +28,22 @@ import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver; import org.springframework.context.support.GenericApplicationContext; import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.Environment; +import org.springframework.core.io.ResourceLoader; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.javapoet.TypeName; import org.springframework.javapoet.TypeSpec; +import org.springframework.lang.Nullable; /** * Internal code generator to create the {@link ApplicationContextInitializer}. @@ -89,15 +97,15 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class); code.addStatement("$L.setDependencyComparator($T.INSTANCE)", BEAN_FACTORY_VARIABLE, AnnotationAwareOrderComparator.class); - ArgumentCodeGenerator argCodeGenerator = createInitializerMethodsArgumentCodeGenerator(); + ArgumentCodeGenerator argCodeGenerator = createInitializerMethodArgumentCodeGenerator(); for (MethodReference initializer : this.initializers) { code.addStatement(initializer.toInvokeCodeBlock(argCodeGenerator, this.generatedClass.getName())); } return code.build(); } - private ArgumentCodeGenerator createInitializerMethodsArgumentCodeGenerator() { - return ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, BEAN_FACTORY_VARIABLE); + static ArgumentCodeGenerator createInitializerMethodArgumentCodeGenerator() { + return ArgumentCodeGenerator.from(new InitializerMethodArgumentCodeGenerator()); } GeneratedClass getGeneratedClass() { @@ -114,4 +122,30 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia this.initializers.add(methodReference); } + private static class InitializerMethodArgumentCodeGenerator implements Function { + + @Override + @Nullable + public CodeBlock apply(TypeName typeName) { + return (typeName instanceof ClassName className ? apply(className) : null); + } + + @Nullable + private CodeBlock apply(ClassName className) { + String name = className.canonicalName(); + if (name.equals(DefaultListableBeanFactory.class.getName()) + || name.equals(ConfigurableListableBeanFactory.class.getName())) { + return CodeBlock.of(BEAN_FACTORY_VARIABLE); + } + else if (name.equals(ConfigurableEnvironment.class.getName()) + || name.equals(Environment.class.getName())) { + return CodeBlock.of("$L.getConfigurableEnvironment()", APPLICATION_CONTEXT_VARIABLE); + } + else if (name.equals(ResourceLoader.class.getName())) { + return CodeBlock.of(APPLICATION_CONTEXT_VARIABLE); + } + return null; + } + } + } diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java new file mode 100644 index 00000000000..1155be09528 --- /dev/null +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.aot; + +import java.util.stream.Stream; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.AbstractBeanFactory; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.Environment; +import org.springframework.core.env.StandardEnvironment; +import org.springframework.core.io.ResourceLoader; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link ApplicationContextInitializationCodeGenerator}. + * + * @author Stephane Nicoll + */ +class ApplicationContextInitializationCodeGeneratorTests { + + private static final ArgumentCodeGenerator argCodeGenerator = ApplicationContextInitializationCodeGenerator. + createInitializerMethodArgumentCodeGenerator(); + + @ParameterizedTest + @MethodSource("methodArguments") + void argumentsForSupportedTypesAreResolved(Class target, String expectedArgument) { + CodeBlock code = CodeBlock.of(expectedArgument); + assertThat(argCodeGenerator.generateCode(ClassName.get(target))).isEqualTo(code); + } + + @Test + void argumentForUnsupportedBeanFactoryIsNotResolved() { + assertThat(argCodeGenerator.generateCode(ClassName.get(AbstractBeanFactory.class))).isNull(); + } + + @Test + void argumentForUnsupportedEnvironmentIsNotResolved() { + assertThat(argCodeGenerator.generateCode(ClassName.get(StandardEnvironment.class))).isNull(); + } + + static Stream methodArguments() { + String applicationContext = "applicationContext"; + String environment = applicationContext + ".getConfigurableEnvironment()"; + return Stream.of( + Arguments.of(DefaultListableBeanFactory.class, "beanFactory"), + Arguments.of(ConfigurableListableBeanFactory.class, "beanFactory"), + Arguments.of(ConfigurableEnvironment.class, environment), + Arguments.of(Environment.class, environment), + Arguments.of(ResourceLoader.class, applicationContext)); + } + +}