From 2b45fd438827cc0b62484283676b1ac8bd11cf18 Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Wed, 7 Sep 2022 16:04:29 +0200 Subject: [PATCH] Allow bean factory initialization to have a more flexible signature This commit allows bean factory initialization to use a more flexible signature than just consuming the DefaultListableBeanFactory. The environment and the resource loader can now be specified if necessary. See gh-29005 --- .../aot/BeanFactoryInitializationCode.java | 15 +++- ...ionContextInitializationCodeGenerator.java | 40 +++++++++- ...ntextInitializationCodeGeneratorTests.java | 77 +++++++++++++++++++ 3 files changed, 125 insertions(+), 7 deletions(-) create mode 100644 spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java index 654fdda5866..e7da3299e81 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java @@ -24,6 +24,7 @@ import org.springframework.aot.generate.MethodReference; * perform bean factory initialization. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see BeanFactoryInitializationAotContribution */ @@ -41,10 +42,16 @@ public interface BeanFactoryInitializationCode { GeneratedMethods getMethods(); /** - * Add an initializer method call. - * @param methodReference a reference to the initialize method to call. The - * referenced method must have the same functional signature as - * {@code Consumer}. + * Add an initializer method call. An initializer can use a flexible signature, + * using any of the following: + * + * @param methodReference a reference to the initialize method to call. */ void addInitializer(MethodReference methodReference); diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java index b2bf870e2f5..1be6885e9f8 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java @@ -18,6 +18,7 @@ package org.springframework.context.aot; import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import javax.lang.model.element.Modifier; @@ -27,15 +28,22 @@ import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver; import org.springframework.context.support.GenericApplicationContext; import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.Environment; +import org.springframework.core.io.ResourceLoader; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.javapoet.TypeName; import org.springframework.javapoet.TypeSpec; +import org.springframework.lang.Nullable; /** * Internal code generator to create the {@link ApplicationContextInitializer}. @@ -89,15 +97,15 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class); code.addStatement("$L.setDependencyComparator($T.INSTANCE)", BEAN_FACTORY_VARIABLE, AnnotationAwareOrderComparator.class); - ArgumentCodeGenerator argCodeGenerator = createInitializerMethodsArgumentCodeGenerator(); + ArgumentCodeGenerator argCodeGenerator = createInitializerMethodArgumentCodeGenerator(); for (MethodReference initializer : this.initializers) { code.addStatement(initializer.toInvokeCodeBlock(argCodeGenerator, this.generatedClass.getName())); } return code.build(); } - private ArgumentCodeGenerator createInitializerMethodsArgumentCodeGenerator() { - return ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, BEAN_FACTORY_VARIABLE); + static ArgumentCodeGenerator createInitializerMethodArgumentCodeGenerator() { + return ArgumentCodeGenerator.from(new InitializerMethodArgumentCodeGenerator()); } GeneratedClass getGeneratedClass() { @@ -114,4 +122,30 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia this.initializers.add(methodReference); } + private static class InitializerMethodArgumentCodeGenerator implements Function { + + @Override + @Nullable + public CodeBlock apply(TypeName typeName) { + return (typeName instanceof ClassName className ? apply(className) : null); + } + + @Nullable + private CodeBlock apply(ClassName className) { + String name = className.canonicalName(); + if (name.equals(DefaultListableBeanFactory.class.getName()) + || name.equals(ConfigurableListableBeanFactory.class.getName())) { + return CodeBlock.of(BEAN_FACTORY_VARIABLE); + } + else if (name.equals(ConfigurableEnvironment.class.getName()) + || name.equals(Environment.class.getName())) { + return CodeBlock.of("$L.getConfigurableEnvironment()", APPLICATION_CONTEXT_VARIABLE); + } + else if (name.equals(ResourceLoader.class.getName())) { + return CodeBlock.of(APPLICATION_CONTEXT_VARIABLE); + } + return null; + } + } + } diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java new file mode 100644 index 00000000000..1155be09528 --- /dev/null +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.aot; + +import java.util.stream.Stream; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.AbstractBeanFactory; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.Environment; +import org.springframework.core.env.StandardEnvironment; +import org.springframework.core.io.ResourceLoader; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link ApplicationContextInitializationCodeGenerator}. + * + * @author Stephane Nicoll + */ +class ApplicationContextInitializationCodeGeneratorTests { + + private static final ArgumentCodeGenerator argCodeGenerator = ApplicationContextInitializationCodeGenerator. + createInitializerMethodArgumentCodeGenerator(); + + @ParameterizedTest + @MethodSource("methodArguments") + void argumentsForSupportedTypesAreResolved(Class target, String expectedArgument) { + CodeBlock code = CodeBlock.of(expectedArgument); + assertThat(argCodeGenerator.generateCode(ClassName.get(target))).isEqualTo(code); + } + + @Test + void argumentForUnsupportedBeanFactoryIsNotResolved() { + assertThat(argCodeGenerator.generateCode(ClassName.get(AbstractBeanFactory.class))).isNull(); + } + + @Test + void argumentForUnsupportedEnvironmentIsNotResolved() { + assertThat(argCodeGenerator.generateCode(ClassName.get(StandardEnvironment.class))).isNull(); + } + + static Stream methodArguments() { + String applicationContext = "applicationContext"; + String environment = applicationContext + ".getConfigurableEnvironment()"; + return Stream.of( + Arguments.of(DefaultListableBeanFactory.class, "beanFactory"), + Arguments.of(ConfigurableListableBeanFactory.class, "beanFactory"), + Arguments.of(ConfigurableEnvironment.class, environment), + Arguments.of(Environment.class, environment), + Arguments.of(ResourceLoader.class, applicationContext)); + } + +}