diff --git a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java index e87b2e9fe77..35086d756f7 100644 --- a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java +++ b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java @@ -32,6 +32,9 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.aop.framework.autoproxy.AutoProxyUtils; +import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.ResourceHints; import org.springframework.aot.hint.TypeReference; import org.springframework.beans.PropertyValues; @@ -39,6 +42,9 @@ import org.springframework.beans.factory.BeanClassLoaderAware; import org.springframework.beans.factory.BeanDefinitionStoreException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor; +import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; @@ -56,6 +62,7 @@ import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; import org.springframework.beans.factory.support.BeanNameGenerator; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationStartupAware; import org.springframework.context.EnvironmentAware; import org.springframework.context.ResourceLoaderAware; @@ -101,8 +108,8 @@ import org.springframework.util.ClassUtils; * @since 3.0 */ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPostProcessor, - AotContributingBeanFactoryPostProcessor, PriorityOrdered, ResourceLoaderAware, ApplicationStartupAware, - BeanClassLoaderAware, EnvironmentAware { + AotContributingBeanFactoryPostProcessor, BeanFactoryInitializationAotProcessor, PriorityOrdered, + ResourceLoaderAware, ApplicationStartupAware, BeanClassLoaderAware, EnvironmentAware { /** * A {@code BeanNameGenerator} using fully qualified class names as default bean names. @@ -288,6 +295,13 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo ? new ImportAwareBeanFactoryConfiguration(beanFactory) : null); } + @Override + public BeanFactoryInitializationAotContribution processAheadOfTime( + ConfigurableListableBeanFactory beanFactory) { + return (beanFactory.containsBean(IMPORT_REGISTRY_BEAN_NAME) + ? new AotContribution(beanFactory) : null); + } + /** * Build and validate a configuration model based on the registry of * {@link Configuration} classes. @@ -555,4 +569,82 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo } + private class AotContribution implements BeanFactoryInitializationAotContribution { + + private static final String BEAN_FACTORY_VARIABLE = BeanFactoryInitializationCode.BEAN_FACTORY_VARIABLE; + + private static final ParameterizedTypeName STRING_STRING_MAP = ParameterizedTypeName + .get(Map.class, String.class, String.class); + + private static final String MAPPINGS_VARIABLE = "mappings"; + + + private final ConfigurableListableBeanFactory beanFactory; + + + public AotContribution(ConfigurableListableBeanFactory beanFactory) { + this.beanFactory = beanFactory; + } + + + @Override + public void applyTo(GenerationContext generationContext, + BeanFactoryInitializationCode beanFactoryInitializationCode) { + + Map mappings = buildImportAwareMappings(); + if (!mappings.isEmpty()) { + GeneratedMethod generatedMethod = beanFactoryInitializationCode + .getMethodGenerator() + .generateMethod("addImportAwareBeanPostProcessors") + .using(builder -> generateAddPostProcessorMethod(builder, + mappings)); + beanFactoryInitializationCode + .addInitializer(MethodReference.of(generatedMethod.getName())); + ResourceHints hints = generationContext.getRuntimeHints().resources(); + mappings.forEach( + (target, from) -> hints.registerType(TypeReference.of(from))); + } + } + + private void generateAddPostProcessorMethod(MethodSpec.Builder builder, + Map mappings) { + + builder.addJavadoc( + "Add ImportAwareBeanPostProcessor to support ImportAware beans"); + builder.addModifiers(Modifier.PRIVATE); + builder.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_VARIABLE); + builder.addCode(generateAddPostProcessorCode(mappings)); + } + + private CodeBlock generateAddPostProcessorCode(Map mappings) { + CodeBlock.Builder builder = CodeBlock.builder(); + builder.addStatement("$T $L = new $T<>()", STRING_STRING_MAP, + MAPPINGS_VARIABLE, HashMap.class); + mappings.forEach((type, from) -> builder.addStatement("$L.put($S, $S)", + MAPPINGS_VARIABLE, type, from)); + builder.addStatement("$L.addBeanPostProcessor(new $T($L))", + BEAN_FACTORY_VARIABLE, ImportAwareAotBeanPostProcessor.class, + MAPPINGS_VARIABLE); + return builder.build(); + } + + private Map buildImportAwareMappings() { + ImportRegistry importRegistry = this.beanFactory + .getBean(IMPORT_REGISTRY_BEAN_NAME, ImportRegistry.class); + Map mappings = new LinkedHashMap<>(); + for (String name : this.beanFactory.getBeanDefinitionNames()) { + Class beanType = this.beanFactory.getType(name); + if (beanType != null && ImportAware.class.isAssignableFrom(beanType)) { + String target = ClassUtils.getUserClass(beanType).getName(); + AnnotationMetadata from = importRegistry.getImportingClassFor(target); + if (from != null) { + mappings.put(target, from.getClassName()); + } + } + } + return mappings; + } + + } + } diff --git a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java new file mode 100644 index 00000000000..96b67559545 --- /dev/null +++ b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.annotation; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import javax.lang.model.element.Modifier; + +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.DefaultGenerationContext; +import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.InMemoryGeneratedFiles; +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.test.generator.compile.Compiled; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; +import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration; +import org.springframework.context.testfixture.context.generator.annotation.ImportAwareConfiguration; +import org.springframework.context.testfixture.context.generator.annotation.ImportConfiguration; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.javapoet.TypeSpec; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.entry; + +/** + * Tests for {@link ConfigurationClassPostProcessor} AOT contributions. + * + * @author Phillip Webb + * @author Stephane Nicoll + */ +class ConfigurationClassPostProcessorAotContributionTests { + + private DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + + private InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); + + private DefaultGenerationContext generationContext = new DefaultGenerationContext( + this.generatedFiles); + + private MockBeanFactoryInitializationCode beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); + + @Test + void applyToWhenHasImportAwareConfigurationRegistersBeanPostProcessorWithMapEntry() { + BeanFactoryInitializationAotContribution contribution = getContribution( + ImportConfiguration.class); + contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); + testCompiledResult((initializer, compiled) -> { + DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); + initializer.accept(freshBeanFactory); + ImportAwareAotBeanPostProcessor postProcessor = (ImportAwareAotBeanPostProcessor) freshBeanFactory + .getBeanPostProcessors().get(0); + assertPostProcessorEntry(postProcessor, ImportAwareConfiguration.class, + ImportConfiguration.class); + }); + } + + @Test + void applyToWhenHasImportAwareConfigurationRegistersHints() { + BeanFactoryInitializationAotContribution contribution = getContribution( + ImportConfiguration.class); + contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); + assertThat(generationContext.getRuntimeHints().resources().resourcePatterns()) + .singleElement() + .satisfies(resourceHint -> assertThat(resourceHint.getIncludes()) + .containsOnly( + "org/springframework/context/testfixture/context/generator/annotation/ImportConfiguration.class")); + } + + @Test + void processAheadOfTimeWhenNoImportAwareConfigurationReturnsNull() { + assertThat(getContribution(SimpleConfiguration.class)).isNull(); + } + + @Nullable + private BeanFactoryInitializationAotContribution getContribution(Class type) { + this.beanFactory.registerBeanDefinition("configuration", + new RootBeanDefinition(type)); + ConfigurationClassPostProcessor postProcessor = new ConfigurationClassPostProcessor(); + postProcessor.postProcessBeanFactory(this.beanFactory); + return postProcessor.processAheadOfTime(this.beanFactory); + } + + @SuppressWarnings("unchecked") + private void testCompiledResult( + BiConsumer, Compiled> result) { + JavaFile javaFile = createJavaFile(); + this.generationContext.writeGeneratedContent(); + TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo, + compiled -> result.accept(compiled.getInstance(Consumer.class), + compiled)); + } + + private JavaFile createJavaFile() { + MethodReference methodReference = this.beanFactoryInitializationCode.initializers + .get(0); + TypeSpec.Builder builder = TypeSpec.classBuilder("TestConsumer"); + builder.addModifiers(Modifier.PUBLIC); + builder.addSuperinterface(ParameterizedTypeName.get(Consumer.class, + DefaultListableBeanFactory.class)); + builder.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) + .addParameter(DefaultListableBeanFactory.class, "beanFactory") + .addStatement( + methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .build()); + this.beanFactoryInitializationCode.generatedMethods + .doWithMethodSpecs(builder::addMethod); + return JavaFile.builder("__", builder.build()).build(); + } + + private void assertPostProcessorEntry(ImportAwareAotBeanPostProcessor postProcessor, + Class key, Class value) { + assertThat(postProcessor).extracting("importsMapping") + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsExactly(entry(key.getName(), value.getName())); + } + + class MockBeanFactoryInitializationCode implements BeanFactoryInitializationCode { + + private final GeneratedMethods generatedMethods = new GeneratedMethods(); + + private final List initializers = new ArrayList<>(); + + @Override + public MethodGenerator getMethodGenerator() { + return this.generatedMethods; + } + + @Override + public void addInitializer(MethodReference methodReference) { + this.initializers.add(methodReference); + } + + } + +}