diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistry.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistry.java index 3afc7c885af..0ff3fbc3327 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistry.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistry.java @@ -27,9 +27,9 @@ import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; -import org.springframework.util.StringUtils; /** * An internal class used to track {@link BeanOverrideHandler}-related state after @@ -110,14 +110,13 @@ class BeanOverrideRegistry { void inject(Object target, BeanOverrideHandler handler) { Field field = handler.getField(); Assert.notNull(field, () -> "BeanOverrideHandler must have a non-null field: " + handler); - String beanName = this.handlerToBeanNameMap.get(handler); - Assert.state(StringUtils.hasLength(beanName), () -> "No bean found for BeanOverrideHandler: " + handler); - inject(field, target, beanName); + Object bean = getBeanForHandler(handler, field.getType()); + Assert.state(bean != null, () -> "No bean found for BeanOverrideHandler: " + handler); + inject(field, target, bean); } - private void inject(Field field, Object target, String beanName) { + private void inject(Field field, Object target, Object bean) { try { - Object bean = this.beanFactory.getBean(beanName, field.getType()); ReflectionUtils.makeAccessible(field); ReflectionUtils.setField(field, target, bean); } @@ -126,4 +125,13 @@ class BeanOverrideRegistry { } } + @Nullable + private Object getBeanForHandler(BeanOverrideHandler handler, Class requiredType) { + String beanName = this.handlerToBeanNameMap.get(handler); + if (beanName != null) { + return this.beanFactory.getBean(beanName, requiredType); + } + return null; + } + }