diff --git a/spring-context/src/main/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessor.java b/spring-context/src/main/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessor.java index 9f789fd20c6..87d808523c0 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessor.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessor.java @@ -40,6 +40,7 @@ import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.core.annotation.MergedAnnotations; import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; /** @@ -48,6 +49,7 @@ import org.springframework.util.ReflectionUtils; * underlying {@link ReflectiveProcessor} implementations. * * @author Stephane Nicoll + * @author Sebastien Deleuze */ class ReflectiveProcessorBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor { @@ -58,23 +60,30 @@ class ReflectiveProcessorBeanRegistrationAotProcessor implements BeanRegistratio public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) { Class beanClass = registeredBean.getBeanClass(); Set entries = new LinkedHashSet<>(); - if (isReflective(beanClass)) { - entries.add(createEntry(beanClass)); + processType(entries, beanClass); + for (Class implementedInterface : ClassUtils.getAllInterfacesForClass(beanClass)) { + processType(entries, implementedInterface); } - doWithReflectiveConstructors(beanClass, constructor -> - entries.add(createEntry(constructor))); - ReflectionUtils.doWithFields(beanClass, field -> - entries.add(createEntry(field)), this::isReflective); - ReflectionUtils.doWithMethods(beanClass, method -> - entries.add(createEntry(method)), this::isReflective); if (!entries.isEmpty()) { return new ReflectiveProcessorBeanRegistrationAotContribution(entries); } return null; } - private void doWithReflectiveConstructors(Class beanClass, Consumer> consumer) { - for (Constructor constructor : beanClass.getDeclaredConstructors()) { + private void processType(Set entries, Class typeToProcess) { + if (isReflective(typeToProcess)) { + entries.add(createEntry(typeToProcess)); + } + doWithReflectiveConstructors(typeToProcess, constructor -> + entries.add(createEntry(constructor))); + ReflectionUtils.doWithFields(typeToProcess, field -> + entries.add(createEntry(field)), this::isReflective); + ReflectionUtils.doWithMethods(typeToProcess, method -> + entries.add(createEntry(method)), this::isReflective); + } + + private void doWithReflectiveConstructors(Class typeToProcess, Consumer> consumer) { + for (Constructor constructor : typeToProcess.getDeclaredConstructors()) { if (isReflective(constructor)) { consumer.accept(constructor); } @@ -82,13 +91,13 @@ class ReflectiveProcessorBeanRegistrationAotProcessor implements BeanRegistratio } private boolean isReflective(AnnotatedElement element) { - return MergedAnnotations.from(element).isPresent(Reflective.class); + return MergedAnnotations.from(element, MergedAnnotations.SearchStrategy.TYPE_HIERARCHY).isPresent(Reflective.class); } @SuppressWarnings("unchecked") private Entry createEntry(AnnotatedElement element) { Class[] processorClasses = (Class[]) - MergedAnnotations.from(element).get(Reflective.class).getClassArray("value"); + MergedAnnotations.from(element, MergedAnnotations.SearchStrategy.TYPE_HIERARCHY).get(Reflective.class).getClassArray("value"); List processors = Arrays.stream(processorClasses).distinct() .map(processorClass -> this.processors.computeIfAbsent(processorClass, BeanUtils::instantiateClass)) .toList(); @@ -125,7 +134,6 @@ class ReflectiveProcessorBeanRegistrationAotProcessor implements BeanRegistratio @Override public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { RuntimeHints runtimeHints = generationContext.getRuntimeHints(); - RuntimeHintsUtils.registerAnnotation(runtimeHints, Reflective.class); this.entries.forEach(entry -> { AnnotatedElement element = entry.element(); entry.processor().registerReflectionHints(runtimeHints.reflection(), element); diff --git a/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java index 548977aa1c3..6cb85f517f7 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java @@ -48,6 +48,7 @@ import static org.mockito.Mockito.mock; * Tests for {@link ReflectiveProcessorBeanRegistrationAotProcessor}. * * @author Stephane Nicoll + * @author Sebastien Deleuze */ class ReflectiveProcessorBeanRegistrationAotProcessorTests { @@ -109,6 +110,28 @@ class ReflectiveProcessorBeanRegistrationAotProcessorTests { assertThat(RuntimeHintsPredicates.proxies().forInterfaces(RetryInvoker.class, SynthesizedAnnotation.class)).accepts(runtimeHints); } + @Test + void shouldProcessAnnotationOnInterface() { + process(SampleMethodAnnotatedBeanWithInterface.class); + assertThat(this.generationContext.getRuntimeHints().reflection().getTypeHint(SampleInterface.class)) + .satisfies(typeHint -> assertThat(typeHint.methods()).singleElement() + .satisfies(methodHint -> assertThat(methodHint.getName()).isEqualTo("managed"))); + assertThat(this.generationContext.getRuntimeHints().reflection().getTypeHint(SampleMethodAnnotatedBeanWithInterface.class)) + .satisfies(typeHint -> assertThat(typeHint.methods()).singleElement() + .satisfies(methodHint -> assertThat(methodHint.getName()).isEqualTo("managed"))); + } + + @Test + void shouldProcessAnnotationOnInheritedClass() { + process(SampleMethodAnnotatedBeanWithInheritance.class); + assertThat(this.generationContext.getRuntimeHints().reflection().getTypeHint(SampleInheritedClass.class)) + .satisfies(typeHint -> assertThat(typeHint.methods()).singleElement() + .satisfies(methodHint -> assertThat(methodHint.getName()).isEqualTo("managed"))); + assertThat(this.generationContext.getRuntimeHints().reflection().getTypeHint(SampleMethodAnnotatedBeanWithInheritance.class)) + .satisfies(typeHint -> assertThat(typeHint.methods()).singleElement() + .satisfies(methodHint -> assertThat(methodHint.getName()).isEqualTo("managed"))); + } + @Nullable private BeanRegistrationAotContribution createContribution(Class beanClass) { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); @@ -193,6 +216,28 @@ class ReflectiveProcessorBeanRegistrationAotProcessorTests { } + static class SampleMethodAnnotatedBeanWithInterface implements SampleInterface { + + @Override + public void managed() { + } + + public void notManaged() { + } + + } + + static class SampleMethodAnnotatedBeanWithInheritance extends SampleInheritedClass { + + @Override + public void managed() { + } + + public void notManaged() { + } + + } + @Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE}) @Retention(RetentionPolicy.RUNTIME) @Documented @@ -214,4 +259,17 @@ class ReflectiveProcessorBeanRegistrationAotProcessorTests { } + interface SampleInterface { + + @Reflective + void managed(); + } + + static class SampleInheritedClass { + + @Reflective + void managed() { + } + } + }