diff --git a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java index 57ccd264d3e..435a5f659a0 100644 --- a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java +++ b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java @@ -20,15 +20,24 @@ import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Set; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.HierarchicalBeanFactory; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Condition; import org.springframework.context.annotation.ConditionContext; import org.springframework.context.annotation.ConfigurationCondition; +import org.springframework.core.ResolvableType; import org.springframework.core.type.AnnotatedTypeMetadata; import org.springframework.core.type.MethodMetadata; import org.springframework.util.Assert; @@ -43,6 +52,7 @@ import org.springframework.util.StringUtils; * * @author Phillip Webb * @author Dave Syer + * @author Jakub Kubrynski */ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondition { @@ -100,8 +110,8 @@ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondit boolean considerHierarchy = beans.getStrategy() == SearchStrategy.ALL; for (String type : beans.getTypes()) { - beanNames.addAll(Arrays.asList(getBeanNamesForType(beanFactory, type, - context.getClassLoader(), considerHierarchy))); + beanNames.addAll(getBeanNamesForType(beanFactory, type, + context.getClassLoader(), considerHierarchy)); } for (String annotation : beans.getAnnotations()) { @@ -126,25 +136,94 @@ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondit return beanFactory.containsLocalBean(beanName); } - private String[] getBeanNamesForType(ConfigurableListableBeanFactory beanFactory, - String type, ClassLoader classLoader, boolean considerHierarchy) - throws LinkageError { - // eagerInit set to false to prevent early instantiation (some - // factory beans will not be able to determine their object type at this - // stage, so those are not eligible for matching this condition) + private Collection getBeanNamesForType( + ConfigurableListableBeanFactory beanFactory, String type, + ClassLoader classLoader, boolean considerHierarchy) throws LinkageError { try { - Class typeClass = ClassUtils.forName(type, classLoader); - if (considerHierarchy) { - return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(beanFactory, - typeClass, false, false); - } - return beanFactory.getBeanNamesForType(typeClass, false, false); + Set result = new LinkedHashSet(); + collectBeanNamesForType(result, beanFactory, + ClassUtils.forName(type, classLoader), considerHierarchy); + return result; } catch (ClassNotFoundException ex) { - return NO_BEANS; + return Collections.emptySet(); + } + } + + private void collectBeanNamesForType(Set result, + ListableBeanFactory beanFactory, Class type, boolean considerHierarchy) { + // eagerInit set to false to prevent early instantiation + result.addAll(Arrays.asList(beanFactory.getBeanNamesForType(type, true, false))); + if (beanFactory instanceof ConfigurableListableBeanFactory) { + collectBeanNamesForTypeFromFactoryBeans(result, + (ConfigurableListableBeanFactory) beanFactory, type); + } + if (considerHierarchy && beanFactory instanceof HierarchicalBeanFactory) { + BeanFactory parent = ((HierarchicalBeanFactory) beanFactory) + .getParentBeanFactory(); + if (parent instanceof ListableBeanFactory) { + collectBeanNamesForType(result, (ListableBeanFactory) parent, type, + considerHierarchy); + } + } + } + + /** + * Attempt to collect bean names for type by considering FactoryBean generics. Some + * factory beans will not be able to determine their object type at this stage, so + * those are not eligible for matching this condition. + */ + private void collectBeanNamesForTypeFromFactoryBeans(Set result, + ConfigurableListableBeanFactory beanFactory, Class type) { + String[] names = beanFactory.getBeanNamesForType(FactoryBean.class, true, false); + for (String name : names) { + name = BeanFactoryUtils.transformedBeanName(name); + BeanDefinition beanDefinition = beanFactory.getBeanDefinition(name); + Class generic = getFactoryBeanGeneric(beanFactory, beanDefinition); + if (generic != null && ClassUtils.isAssignable(type, generic)) { + result.add(name); + } } } + private Class getFactoryBeanGeneric(ConfigurableListableBeanFactory beanFactory, + BeanDefinition definition) { + try { + if (StringUtils.hasLength(definition.getFactoryBeanName()) + && StringUtils.hasLength(definition.getFactoryMethodName())) { + return getConfigurationClassFactoryBeanGeneric(beanFactory, definition); + } + if (StringUtils.hasLength(definition.getBeanClassName())) { + return getDirectFactoryBeanGeneric(beanFactory, definition); + } + } + catch (Exception ex) { + } + return null; + } + + private Class getConfigurationClassFactoryBeanGeneric( + ConfigurableListableBeanFactory beanFactory, BeanDefinition definition) + throws Exception { + BeanDefinition factoryDefinition = beanFactory.getBeanDefinition(definition + .getFactoryBeanName()); + Class factoryClass = ClassUtils.forName(factoryDefinition.getBeanClassName(), + beanFactory.getBeanClassLoader()); + Method method = ReflectionUtils.findMethod(factoryClass, + definition.getFactoryMethodName()); + return ResolvableType.forMethodReturnType(method).as(FactoryBean.class) + .resolveGeneric(); + } + + private Class getDirectFactoryBeanGeneric( + ConfigurableListableBeanFactory beanFactory, BeanDefinition definition) + throws ClassNotFoundException, LinkageError { + Class factoryBeanClass = ClassUtils.forName(definition.getBeanClassName(), + beanFactory.getBeanClassLoader()); + return ResolvableType.forClass(factoryBeanClass).as(FactoryBean.class) + .resolveGeneric(); + } + private String[] getBeanNamesForAnnotation( ConfigurableListableBeanFactory beanFactory, String type, ClassLoader classLoader, boolean considerHierarchy) throws LinkageError { diff --git a/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java b/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java index e96531ba5ce..50ab687e5da 100644 --- a/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java +++ b/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java @@ -16,7 +16,6 @@ package org.springframework.boot.autoconfigure.condition; -import org.junit.Ignore; import org.junit.Test; import org.springframework.beans.factory.FactoryBean; import org.springframework.boot.autoconfigure.PropertyPlaceholderAutoConfiguration; @@ -38,6 +37,7 @@ import static org.junit.Assert.assertTrue; * * @author Dave Syer * @author Phillip Webb + * @author Jakub Kubrynski */ @SuppressWarnings("resource") public class ConditionalOnMissingBeanTests { @@ -102,7 +102,7 @@ public class ConditionalOnMissingBeanTests { @Test public void testAnnotationOnMissingBeanConditionWithEagerFactoryBean() { this.context.register(FooConfiguration.class, OnAnnotationConfiguration.class, - ConfigurationWithFactoryBean.class, + FactoryBeanXmlConfiguration.class, PropertyPlaceholderAutoConfiguration.class); this.context.refresh(); assertFalse(this.context.containsBean("bar")); @@ -111,22 +111,44 @@ public class ConditionalOnMissingBeanTests { } @Test - @Ignore("This will never work - you need to use XML for FactoryBeans, or else call getObject() inside the @Bean method") public void testOnMissingBeanConditionWithFactoryBean() { - this.context.register(ExampleBeanAndFactoryBeanConfiguration.class, + this.context.register(FactoryBeanConfiguration.class, + ConditionalOnFactoryBean.class, PropertyPlaceholderAutoConfiguration.class); this.context.refresh(); - // There should be only one - this.context.getBean(ExampleBean.class); + assertThat(this.context.getBean(ExampleBean.class).toString(), + equalTo("fromFactory")); + } + + @Test + public void testOnMissingBeanConditionWithConcreteFactoryBean() { + this.context.register(ConcreteFactoryBeanConfiguration.class, + ConditionalOnFactoryBean.class, + PropertyPlaceholderAutoConfiguration.class); + this.context.refresh(); + assertThat(this.context.getBean(ExampleBean.class).toString(), + equalTo("fromFactory")); + } + + @Test + public void testOnMissingBeanConditionWithUnhelpfulFactoryBean() { + this.context.register(UnhelpfulFactoryBeanConfiguration.class, + ConditionalOnFactoryBean.class, + PropertyPlaceholderAutoConfiguration.class); + this.context.refresh(); + // We could not tell that the FactoryBean would ultimately create an ExampleBean + assertThat(this.context.getBeansOfType(ExampleBean.class).values().size(), + equalTo(2)); } @Test public void testOnMissingBeanConditionWithFactoryBeanInXml() { - this.context.register(ConfigurationWithFactoryBean.class, + this.context.register(FactoryBeanXmlConfiguration.class, + ConditionalOnFactoryBean.class, PropertyPlaceholderAutoConfiguration.class); this.context.refresh(); - // There should be only one - this.context.getBean(ExampleBean.class); + assertThat(this.context.getBean(ExampleBean.class).toString(), + equalTo("fromFactory")); } @Configuration @@ -139,17 +161,41 @@ public class ConditionalOnMissingBeanTests { } @Configuration - protected static class ExampleBeanAndFactoryBeanConfiguration { - + protected static class FactoryBeanConfiguration { @Bean public FactoryBean exampleBeanFactoryBean() { return new ExampleFactoryBean("foo"); } + } + + @Configuration + protected static class ConcreteFactoryBeanConfiguration { + @Bean + public ExampleFactoryBean exampleBeanFactoryBean() { + return new ExampleFactoryBean("foo"); + } + } + + @Configuration + protected static class UnhelpfulFactoryBeanConfiguration { + @Bean + @SuppressWarnings("rawtypes") + public FactoryBean exampleBeanFactoryBean() { + return new ExampleFactoryBean("foo"); + } + } + + @Configuration + @ImportResource("org/springframework/boot/autoconfigure/condition/factorybean.xml") + protected static class FactoryBeanXmlConfiguration { + } + @Configuration + protected static class ConditionalOnFactoryBean { @Bean @ConditionalOnMissingBean(ExampleBean.class) public ExampleBean createExampleBean() { - return new ExampleBean(); + return new ExampleBean("direct"); } } @@ -162,11 +208,6 @@ public class ConditionalOnMissingBeanTests { } } - @Configuration - @ImportResource("org/springframework/boot/autoconfigure/condition/factorybean.xml") - protected static class ConfigurationWithFactoryBean { - } - @Configuration @EnableScheduling protected static class FooConfiguration { @@ -198,7 +239,7 @@ public class ConditionalOnMissingBeanTests { protected static class ExampleBeanConfiguration { @Bean public ExampleBean exampleBean() { - return new ExampleBean(); + return new ExampleBean("test"); } } @@ -208,12 +249,24 @@ public class ConditionalOnMissingBeanTests { @Bean @ConditionalOnMissingBean public ExampleBean exampleBean2() { - return new ExampleBean(); + return new ExampleBean("test"); } } public static class ExampleBean { + + private String value; + + public ExampleBean(String value) { + this.value = value; + } + + @Override + public String toString() { + return this.value; + } + } public static class ExampleFactoryBean implements FactoryBean { @@ -224,7 +277,7 @@ public class ConditionalOnMissingBeanTests { @Override public ExampleBean getObject() throws Exception { - return new ExampleBean(); + return new ExampleBean("fromFactory"); } @Override