diff --git a/spring-context/src/main/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessor.java b/spring-context/src/main/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessor.java index 0484fecc84f..4d80cf146f8 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessor.java +++ b/spring-context/src/main/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessor.java @@ -16,8 +16,12 @@ package org.springframework.context.aot; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -32,8 +36,12 @@ import org.springframework.beans.factory.aot.AotServices; import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.annotation.ImportRuntimeHints; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.core.annotation.MergedAnnotations; import org.springframework.core.log.LogMessage; import org.springframework.lang.Nullable; @@ -45,6 +53,7 @@ import org.springframework.lang.Nullable; * classes or bean methods. * * @author Brian Clozel + * @author Sebastien Deleuze */ class RuntimeHintsBeanFactoryInitializationAotProcessor implements BeanFactoryInitializationAotProcessor { @@ -66,15 +75,50 @@ class RuntimeHintsBeanFactoryInitializationAotProcessor implements BeanFactoryIn Set> registrarClasses = new LinkedHashSet<>(); for (String beanName : beanFactory .getBeanNamesForAnnotation(ImportRuntimeHints.class)) { - ImportRuntimeHints annotation = beanFactory.findAnnotationOnBean(beanName, - ImportRuntimeHints.class); - if (annotation != null) { - registrarClasses.addAll(extractFromBeanDefinition(beanName, annotation)); - } + findAnnotationsOnBean(beanFactory, beanName, + ImportRuntimeHints.class).forEach(annotation -> + registrarClasses.addAll(extractFromBeanDefinition(beanName, annotation))); } return registrarClasses; } + private List findAnnotationsOnBean(ConfigurableListableBeanFactory beanFactory, + String beanName, Class annotationType) { + + List annotations = new ArrayList<>(); + Class beanType = beanFactory.getType(beanName, true); + if (beanType != null) { + MergedAnnotations.from(beanType, MergedAnnotations.SearchStrategy.TYPE_HIERARCHY) + .stream(annotationType) + .filter(MergedAnnotation::isPresent) + .forEach(mergedAnnotation -> annotations.add(mergedAnnotation.synthesize())); + } + if (beanFactory.containsBeanDefinition(beanName)) { + BeanDefinition bd = beanFactory.getBeanDefinition(beanName); + if (bd instanceof RootBeanDefinition rbd) { + // Check raw bean class, e.g. in case of a proxy. + if (rbd.hasBeanClass() && rbd.getFactoryMethodName() == null) { + Class beanClass = rbd.getBeanClass(); + if (beanClass != beanType) { + MergedAnnotations.from(beanClass, MergedAnnotations.SearchStrategy.TYPE_HIERARCHY) + .stream(annotationType) + .filter(MergedAnnotation::isPresent) + .forEach(mergedAnnotation -> annotations.add(mergedAnnotation.synthesize())); + } + } + // Check annotations declared on factory method, if any. + Method factoryMethod = rbd.getResolvedFactoryMethod(); + if (factoryMethod != null) { + MergedAnnotations.from(factoryMethod, MergedAnnotations.SearchStrategy.TYPE_HIERARCHY) + .stream(annotationType) + .filter(MergedAnnotation::isPresent) + .forEach(mergedAnnotation -> annotations.add(mergedAnnotation.synthesize())); + } + } + } + return annotations; + } + private Set> extractFromBeanDefinition(String beanName, ImportRuntimeHints annotation) { diff --git a/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java index 6bd70485c0b..c3542b630fd 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java @@ -46,6 +46,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * Tests for {@link RuntimeHintsBeanFactoryInitializationAotProcessor}. * * @author Brian Clozel + * @author Sebastien Deleuze */ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { @@ -68,6 +69,15 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { assertThatSampleRegistrarContributed(); } + @Test + void shouldProcessRegistrarsOnInheritedConfiguration() { + GenericApplicationContext applicationContext = createApplicationContext( + ExtendedConfigurationWithHints.class); + this.generator.processAheadOfTime(applicationContext, + this.generationContext); + assertThatInheritedSampleRegistrarContributed(); + } + @Test void shouldProcessRegistrarOnBeanMethod() { GenericApplicationContext applicationContext = createApplicationContext( @@ -121,6 +131,14 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { .anyMatch(bundleHint -> "sample".equals(bundleHint.getBaseName())); } + private void assertThatInheritedSampleRegistrarContributed() { + assertThatSampleRegistrarContributed(); + Stream bundleHints = this.generationContext.getRuntimeHints() + .resources().resourceBundleHints(); + assertThat(bundleHints) + .anyMatch(bundleHint -> "extendedSample".equals(bundleHint.getBaseName())); + } + private GenericApplicationContext createApplicationContext( Class... configClasses) { GenericApplicationContext applicationContext = new GenericApplicationContext(); @@ -138,6 +156,10 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { static class ConfigurationWithHints { } + @Configuration(proxyBeanMethods = false) + @ImportRuntimeHints(ExtendedSampleRuntimeHintsRegistrar.class) + static class ExtendedConfigurationWithHints extends ConfigurationWithHints { + } @Configuration(proxyBeanMethods = false) static class ConfigurationWithBeanDeclaringHints { @@ -159,6 +181,15 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { } + public static class ExtendedSampleRuntimeHintsRegistrar implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + hints.resources().registerResourceBundle("extendedSample"); + } + + } + @Configuration(proxyBeanMethods = false) @ImportRuntimeHints(IncrementalRuntimeHintsRegistrar.class) static class ConfigurationWithIncrementalHints {