diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/BeanFactoryUtils.java b/spring-beans/src/main/java/org/springframework/beans/factory/BeanFactoryUtils.java index 2fad1042ff5..2c37b5dfa65 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/BeanFactoryUtils.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/BeanFactoryUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -149,14 +149,7 @@ public abstract class BeanFactoryUtils { if (hbf.getParentBeanFactory() instanceof ListableBeanFactory) { String[] parentResult = beanNamesForTypeIncludingAncestors( (ListableBeanFactory) hbf.getParentBeanFactory(), type); - List resultList = new ArrayList<>(); - resultList.addAll(Arrays.asList(result)); - for (String beanName : parentResult) { - if (!resultList.contains(beanName) && !hbf.containsLocalBean(beanName)) { - resultList.add(beanName); - } - } - result = StringUtils.toStringArray(resultList); + result = mergeNamesWithParent(result, parentResult, hbf); } } return result; @@ -182,14 +175,7 @@ public abstract class BeanFactoryUtils { if (hbf.getParentBeanFactory() instanceof ListableBeanFactory) { String[] parentResult = beanNamesForTypeIncludingAncestors( (ListableBeanFactory) hbf.getParentBeanFactory(), type); - List resultList = new ArrayList<>(); - resultList.addAll(Arrays.asList(result)); - for (String beanName : parentResult) { - if (!resultList.contains(beanName) && !hbf.containsLocalBean(beanName)) { - resultList.add(beanName); - } - } - result = StringUtils.toStringArray(resultList); + result = mergeNamesWithParent(result, parentResult, hbf); } } return result; @@ -225,14 +211,7 @@ public abstract class BeanFactoryUtils { if (hbf.getParentBeanFactory() instanceof ListableBeanFactory) { String[] parentResult = beanNamesForTypeIncludingAncestors( (ListableBeanFactory) hbf.getParentBeanFactory(), type, includeNonSingletons, allowEagerInit); - List resultList = new ArrayList<>(); - resultList.addAll(Arrays.asList(result)); - for (String beanName : parentResult) { - if (!resultList.contains(beanName) && !hbf.containsLocalBean(beanName)) { - resultList.add(beanName); - } - } - result = StringUtils.toStringArray(resultList); + result = mergeNamesWithParent(result, parentResult, hbf); } } return result; @@ -365,6 +344,7 @@ public abstract class BeanFactoryUtils { */ public static String[] beanNamesForAnnotationIncludingAncestors( ListableBeanFactory lbf, Class annotationType) { + Assert.notNull(lbf, "ListableBeanFactory must not be null"); String[] result = lbf.getBeanNamesForAnnotation(annotationType); if (lbf instanceof HierarchicalBeanFactory) { @@ -372,14 +352,7 @@ public abstract class BeanFactoryUtils { if (hbf.getParentBeanFactory() instanceof ListableBeanFactory) { String[] parentResult = beanNamesForAnnotationIncludingAncestors( (ListableBeanFactory) hbf.getParentBeanFactory(), annotationType); - List resultList = new ArrayList<>(); - resultList.addAll(Arrays.asList(result)); - for (String beanName : parentResult) { - if (!resultList.contains(beanName) && !hbf.containsLocalBean(beanName)) { - resultList.add(beanName); - } - } - result = StringUtils.toStringArray(resultList); + result = mergeNamesWithParent(result, parentResult, hbf); } } return result; @@ -477,6 +450,29 @@ public abstract class BeanFactoryUtils { return uniqueBean(type, beansOfType); } + + /** + * Merge the given bean names result with the given parent result. + * @param result the local bean name result + * @param parentResult the parent bean name result (possibly empty) + * @param hbf the local bean factory + * @return the merged result (possibly the local result as-is) + * @since 4.3.15 + */ + private static String[] mergeNamesWithParent(String[] result, String[] parentResult, HierarchicalBeanFactory hbf) { + if (parentResult.length == 0) { + return result; + } + List merged = new ArrayList<>(result.length + parentResult.length); + merged.addAll(Arrays.asList(result)); + for (String beanName : parentResult) { + if (!merged.contains(beanName) && !hbf.containsLocalBean(beanName)) { + merged.add(beanName); + } + } + return StringUtils.toStringArray(merged); + } + /** * Extract a unique bean for the given type from the given Map of matching beans. * @param type type of bean to match @@ -486,11 +482,11 @@ public abstract class BeanFactoryUtils { * @throws NoUniqueBeanDefinitionException if more than one bean of the given type was found */ private static T uniqueBean(Class type, Map matchingBeans) { - int nrFound = matchingBeans.size(); - if (nrFound == 1) { + int count = matchingBeans.size(); + if (count == 1) { return matchingBeans.values().iterator().next(); } - else if (nrFound > 1) { + else if (count > 1) { throw new NoUniqueBeanDefinitionException(type, matchingBeans.keySet()); } else {