diff --git a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/BeanTypeRegistry.java b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/BeanTypeRegistry.java index fc6ee6d46c6..5dcecbe977a 100644 --- a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/BeanTypeRegistry.java +++ b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/BeanTypeRegistry.java @@ -16,6 +16,7 @@ package org.springframework.boot.autoconfigure.condition; +import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.util.Arrays; import java.util.HashMap; @@ -40,6 +41,7 @@ import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.type.MethodMetadata; import org.springframework.core.type.StandardMethodMetadata; import org.springframework.util.Assert; @@ -85,7 +87,7 @@ final class BeanTypeRegistry implements SmartInitializingSingleton { * @param beanFactory the source bean factory * @return the {@link BeanTypeRegistry} for the given bean factory */ - public static BeanTypeRegistry create(ListableBeanFactory beanFactory) { + static BeanTypeRegistry get(ListableBeanFactory beanFactory) { Assert.isInstanceOf(DefaultListableBeanFactory.class, beanFactory); DefaultListableBeanFactory listableBeanFactory = (DefaultListableBeanFactory) beanFactory; Assert.isTrue(listableBeanFactory.isAllowEagerClassLoading(), @@ -101,26 +103,39 @@ final class BeanTypeRegistry implements SmartInitializingSingleton { /** * Return the names of beans matching the given type (including subclasses), judging - * from either bean definitions or the value of {@code getObjectType} in the case of - * FactoryBeans. Will include singletons but not cause early bean initialization. + * from either bean definitions or the value of {@link FactoryBean#getObjectType()} in + * the case of {@link FactoryBean FactoryBeans}. Will include singletons but will not + * cause early bean initialization. * @param type the class or interface to match (must not be {@code null}) * @return the names of beans (or objects created by FactoryBeans) matching the given * object type (including subclasses), or an empty set if none */ Set getNamesForType(Class type) { - if (this.lastBeanDefinitionCount != this.beanFactory.getBeanDefinitionCount()) { - Iterator names = this.beanFactory.getBeanNamesIterator(); - while (names.hasNext()) { - String name = names.next(); - if (!this.beanTypes.containsKey(name)) { - addBeanType(name); - } + updateTypesIfNecessary(); + Set matches = new LinkedHashSet(); + for (Map.Entry> entry : this.beanTypes.entrySet()) { + if (entry.getValue() != null && type.isAssignableFrom(entry.getValue())) { + matches.add(entry.getKey()); } - this.lastBeanDefinitionCount = this.beanFactory.getBeanDefinitionCount(); } + return matches; + } + + /** + * Returns the names of beans annotated with the given {@code annotation}, judging + * from either bean definitions or the value of {@link FactoryBean#getObjectType()} in + * the case of {@link FactoryBean FactoryBeans}. Will include singletons but will not + * cause early bean initialization. + * @param annotation the annotation to match (must not be {@code null}) + * @return the names of beans (or objects created by FactoryBeans) annoated with the + * given annotation, or an empty set if none + */ + Set getNamesForAnnotation(Class annotation) { + updateTypesIfNecessary(); Set matches = new LinkedHashSet(); for (Map.Entry> entry : this.beanTypes.entrySet()) { - if (entry.getValue() != null && type.isAssignableFrom(entry.getValue())) { + if (entry.getValue() != null && AnnotationUtils + .findAnnotation(entry.getValue(), annotation) != null) { matches.add(entry.getKey()); } } @@ -183,6 +198,19 @@ final class BeanTypeRegistry implements SmartInitializingSingleton { && !this.beanFactory.containsSingleton(factoryBeanName)); } + private void updateTypesIfNecessary() { + if (this.lastBeanDefinitionCount != this.beanFactory.getBeanDefinitionCount()) { + Iterator names = this.beanFactory.getBeanNamesIterator(); + while (names.hasNext()) { + String name = names.next(); + if (!this.beanTypes.containsKey(name)) { + addBeanType(name); + } + } + this.lastBeanDefinitionCount = this.beanFactory.getBeanDefinitionCount(); + } + } + /** * Attempt to guess the type that a {@link FactoryBean} will return based on the * generics in its method signature. diff --git a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java index f7d8c52d00b..28d2d886cd1 100644 --- a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java +++ b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java @@ -59,8 +59,6 @@ import org.springframework.util.StringUtils; @Order(Ordered.LOWEST_PRECEDENCE) class OnBeanCondition extends SpringBootCondition implements ConfigurationCondition { - private static final String[] NO_BEANS = {}; - /** * Bean definition attribute name for factory beans to signal their product type (if * known and it can't be deduced from the factory bean class). @@ -267,7 +265,7 @@ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondit private void collectBeanNamesForType(Set result, ListableBeanFactory beanFactory, Class type, boolean considerHierarchy) { - result.addAll(BeanTypeRegistry.create(beanFactory).getNamesForType(type)); + result.addAll(BeanTypeRegistry.get(beanFactory).getNamesForType(type)); if (considerHierarchy && beanFactory instanceof HierarchicalBeanFactory) { BeanFactory parent = ((HierarchicalBeanFactory) beanFactory) .getParentBeanFactory(); @@ -281,34 +279,32 @@ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondit private String[] getBeanNamesForAnnotation( ConfigurableListableBeanFactory beanFactory, String type, ClassLoader classLoader, boolean considerHierarchy) throws LinkageError { - String[] result = NO_BEANS; + Set names = new HashSet(); try { @SuppressWarnings("unchecked") - Class typeClass = (Class) ClassUtils + Class annotationType = (Class) ClassUtils .forName(type, classLoader); - result = beanFactory.getBeanNamesForAnnotation(typeClass); - if (considerHierarchy) { - if (beanFactory - .getParentBeanFactory() instanceof ConfigurableListableBeanFactory) { - String[] parentResult = getBeanNamesForAnnotation( - (ConfigurableListableBeanFactory) beanFactory - .getParentBeanFactory(), - type, classLoader, true); - List resultList = new ArrayList(); - resultList.addAll(Arrays.asList(result)); - for (String beanName : parentResult) { - if (!resultList.contains(beanName) - && !beanFactory.containsLocalBean(beanName)) { - resultList.add(beanName); - } - } - result = StringUtils.toStringArray(resultList); - } - } - return result; + collectBeanNamesForAnnotation(names, beanFactory, annotationType, + considerHierarchy); } - catch (ClassNotFoundException ex) { - return NO_BEANS; + catch (ClassNotFoundException e) { + // Continue + } + return StringUtils.toStringArray(names); + } + + private void collectBeanNamesForAnnotation(Set names, + ListableBeanFactory beanFactory, Class annotationType, + boolean considerHierarchy) { + names.addAll( + BeanTypeRegistry.get(beanFactory).getNamesForAnnotation(annotationType)); + if (considerHierarchy) { + BeanFactory parent = ((HierarchicalBeanFactory) beanFactory) + .getParentBeanFactory(); + if (parent instanceof ListableBeanFactory) { + collectBeanNamesForAnnotation(names, (ListableBeanFactory) parent, + annotationType, considerHierarchy); + } } } diff --git a/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnBeanTests.java b/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnBeanTests.java index 5c798ac9480..4f8dfe1cdcb 100644 --- a/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnBeanTests.java +++ b/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnBeanTests.java @@ -16,10 +16,16 @@ package org.springframework.boot.autoconfigure.condition; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; import java.util.Date; import org.junit.Test; +import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.boot.test.util.EnvironmentTestUtils; @@ -124,6 +130,15 @@ public class ConditionalOnBeanTests { this.context.refresh(); } + @Test + public void beanProducedByFactoryBeanIsConsideredWhenMatchingOnAnnotation() { + this.context.register(FactoryBeanConfiguration.class, + OnAnnotationWithFactoryBeanConfiguration.class); + this.context.refresh(); + assertThat(this.context.containsBean("bar")).isTrue(); + assertThat(this.context.getBeansOfType(ExampleBean.class)).hasSize(1); + } + @Configuration @ConditionalOnBean(name = "foo") protected static class OnBeanNameConfiguration { @@ -220,6 +235,27 @@ public class ConditionalOnBeanTests { } + @Configuration + static class FactoryBeanConfiguration { + + @Bean + public ExampleFactoryBean exampleBeanFactoryBean() { + return new ExampleFactoryBean(); + } + + } + + @Configuration + @ConditionalOnBean(annotation = TestAnnotation.class) + static class OnAnnotationWithFactoryBeanConfiguration { + + @Bean + public String bar() { + return "bar"; + } + + } + protected static class WithPropertyPlaceholderClassNameRegistrar implements ImportBeanDefinitionRegistrar { @@ -233,4 +269,46 @@ public class ConditionalOnBeanTests { } + public static class ExampleFactoryBean implements FactoryBean { + + @Override + public ExampleBean getObject() throws Exception { + return new ExampleBean("fromFactory"); + } + + @Override + public Class getObjectType() { + return ExampleBean.class; + } + + @Override + public boolean isSingleton() { + return false; + } + + } + + @TestAnnotation + public static class ExampleBean { + + private String value; + + public ExampleBean(String value) { + this.value = value; + } + + @Override + public String toString() { + return this.value; + } + + } + + @Target(ElementType.TYPE) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface TestAnnotation { + + } + } diff --git a/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java b/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java index ae3f8a51df9..95c0c63e2a0 100644 --- a/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java +++ b/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java @@ -16,6 +16,11 @@ package org.springframework.boot.autoconfigure.condition; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; import java.util.Date; import org.junit.Test; @@ -285,6 +290,15 @@ public class ConditionalOnMissingBeanTests { assertThat(child.getBeansOfType(ExampleBean.class)).hasSize(2); } + @Test + public void beanProducedByFactoryBeanIsConsideredWhenMatchingOnAnnotation() { + this.context.register(ConcreteFactoryBeanConfiguration.class, + OnAnnotationWithFactoryBeanConfiguration.class); + this.context.refresh(); + assertThat(this.context.containsBean("bar")).isFalse(); + assertThat(this.context.getBeansOfType(ExampleBean.class)).hasSize(1); + } + @Configuration protected static class OnBeanInAncestorsConfiguration { @@ -500,6 +514,17 @@ public class ConditionalOnMissingBeanTests { } + @Configuration + @ConditionalOnMissingBean(annotation = TestAnnotation.class) + protected static class OnAnnotationWithFactoryBeanConfiguration { + + @Bean + public String bar() { + return "bar"; + } + + } + @Configuration @EnableScheduling protected static class FooConfiguration { @@ -554,6 +579,7 @@ public class ConditionalOnMissingBeanTests { } + @TestAnnotation public static class ExampleBean { private String value; @@ -623,4 +649,11 @@ public class ConditionalOnMissingBeanTests { } + @Target(ElementType.TYPE) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface TestAnnotation { + + } + }