diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistry.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistry.java index 3afc7c885af..d9c6deb6447 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistry.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistry.java @@ -16,7 +16,6 @@ package org.springframework.test.context.bean.override; -import java.lang.reflect.Field; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -25,16 +24,14 @@ import java.util.Map.Entry; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import org.springframework.util.ReflectionUtils; -import org.springframework.util.StringUtils; /** * An internal class used to track {@link BeanOverrideHandler}-related state after - * the bean factory has been processed and to provide field injection utilities - * for test execution listeners. + * the bean factory has been processed and to provide lookup facilities to test + * execution listeners. * * @author Simon Baslé * @author Sam Brannen @@ -63,6 +60,7 @@ class BeanOverrideRegistry { *

Also associates a {@linkplain BeanOverrideStrategy#WRAP "wrapping"} handler * with the given {@code beanName}, allowing for subsequent wrapping of the * bean via {@link #wrapBeanIfNecessary(Object, String)}. + * @see #getBeanForHandler(BeanOverrideHandler, Class) */ void registerBeanOverrideHandler(BeanOverrideHandler handler, String beanName) { Assert.state(!this.handlerToBeanNameMap.containsKey(handler), () -> @@ -107,23 +105,22 @@ class BeanOverrideRegistry { return handler.createOverrideInstance(beanName, null, bean, this.beanFactory); } - void inject(Object target, BeanOverrideHandler handler) { - Field field = handler.getField(); - Assert.notNull(field, () -> "BeanOverrideHandler must have a non-null field: " + handler); + /** + * Get the bean instance that was created by the provided {@link BeanOverrideHandler}. + * @param handler the {@code BeanOverrideHandler} that created the bean + * @param requiredType the required bean type + * @return the bean instance, or {@code null} if the provided handler is not + * registered in this registry + * @since 6.2.6 + * @see #registerBeanOverrideHandler(BeanOverrideHandler, String) + */ + @Nullable + Object getBeanForHandler(BeanOverrideHandler handler, Class requiredType) { String beanName = this.handlerToBeanNameMap.get(handler); - Assert.state(StringUtils.hasLength(beanName), () -> "No bean found for BeanOverrideHandler: " + handler); - inject(field, target, beanName); - } - - private void inject(Field field, Object target, String beanName) { - try { - Object bean = this.beanFactory.getBean(beanName, field.getType()); - ReflectionUtils.makeAccessible(field); - ReflectionUtils.setField(field, target, bean); - } - catch (Throwable ex) { - throw new BeanCreationException("Could not inject field '" + field + "'", ex); + if (beanName != null) { + return this.beanFactory.getBean(beanName, requiredType); } + return null; } } diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideTestExecutionListener.java index 736223358cc..ca0499c875d 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideTestExecutionListener.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideTestExecutionListener.java @@ -16,11 +16,15 @@ package org.springframework.test.context.bean.override; +import java.lang.reflect.Field; import java.util.List; +import org.springframework.beans.factory.BeanCreationException; import org.springframework.test.context.TestContext; import org.springframework.test.context.support.AbstractTestExecutionListener; import org.springframework.test.context.support.DependencyInjectionTestExecutionListener; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; /** * {@code TestExecutionListener} that enables {@link BeanOverride @BeanOverride} @@ -94,9 +98,23 @@ public class BeanOverrideTestExecutionListener extends AbstractTestExecutionList .getBean(BeanOverrideContextCustomizer.REGISTRY_BEAN_NAME, BeanOverrideRegistry.class); for (BeanOverrideHandler handler : handlers) { - beanOverrideRegistry.inject(testInstance, handler); + Field field = handler.getField(); + Assert.state(field != null, () -> "BeanOverrideHandler must have a non-null field: " + handler); + Object bean = beanOverrideRegistry.getBeanForHandler(handler, field.getType()); + Assert.state(bean != null, () -> "No bean found for BeanOverrideHandler: " + handler); + injectField(field, testInstance, bean); } } } + private static void injectField(Field field, Object target, Object bean) { + try { + ReflectionUtils.makeAccessible(field); + ReflectionUtils.setField(field, target, bean); + } + catch (Throwable ex) { + throw new BeanCreationException("Could not inject field '" + field + "'", ex); + } + } + }