@ -31,11 +31,14 @@ import org.springframework.core.MethodIntrospector;
import org.springframework.core.ResolvableType ;
import org.springframework.core.ResolvableType ;
import org.springframework.test.context.TestContextAnnotationUtils ;
import org.springframework.test.context.TestContextAnnotationUtils ;
import org.springframework.test.context.bean.override.BeanOverrideProcessor ;
import org.springframework.test.context.bean.override.BeanOverrideProcessor ;
import org.springframework.test.context.bean.override.BeanOverrideStrategy ;
import org.springframework.util.Assert ;
import org.springframework.util.Assert ;
import org.springframework.util.ClassUtils ;
import org.springframework.util.ClassUtils ;
import org.springframework.util.ReflectionUtils ;
import org.springframework.util.ReflectionUtils ;
import org.springframework.util.ReflectionUtils.MethodFilter ;
import org.springframework.util.ReflectionUtils.MethodFilter ;
import org.springframework.util.StringUtils ;
import static org.springframework.test.context.bean.override.BeanOverrideStrategy.REPLACE_DEFINITION ;
import static org.springframework.test.context.bean.override.BeanOverrideStrategy.REPLACE_OR_CREATE_DEFINITION ;
/ * *
/ * *
* { @link BeanOverrideProcessor } implementation for { @link TestBean @TestBean }
* { @link BeanOverrideProcessor } implementation for { @link TestBean @TestBean }
@ -52,30 +55,34 @@ class TestBeanOverrideProcessor implements BeanOverrideProcessor {
@Override
@Override
public TestBeanOverrideMetadata createMetadata ( Annotation overrideAnnotation , Class < ? > testClass , Field field ) {
public TestBeanOverrideMetadata createMetadata ( Annotation overrideAnnotation , Class < ? > testClass , Field field ) {
if ( ! ( overrideAnnotation instanceof TestBean testBeanAnnotation ) ) {
if ( ! ( overrideAnnotation instanceof TestBean testBean ) ) {
throw new IllegalStateException ( "Invalid annotation passed to %s: expected @TestBean on field %s.%s"
throw new IllegalStateException ( "Invalid annotation passed to %s: expected @TestBean on field %s.%s"
. formatted ( getClass ( ) . getSimpleName ( ) , field . getDeclaringClass ( ) . getName ( ) , field . getName ( ) ) ) ;
. formatted ( getClass ( ) . getSimpleName ( ) , field . getDeclaringClass ( ) . getName ( ) , field . getName ( ) ) ) ;
}
}
String beanName = ( ! testBean . name ( ) . isBlank ( ) ? testBean . name ( ) : null ) ;
String methodName = testBean . methodName ( ) ;
BeanOverrideStrategy strategy = ( testBean . enforceOverride ( ) ? REPLACE_DEFINITION : REPLACE_OR_CREATE_DEFINITION ) ;
Method overrideMethod ;
Method overrideMethod ;
String methodName = testBeanAnnotation . methodName ( ) ;
if ( ! methodName . isBlank ( ) ) {
if ( ! methodName . isBlank ( ) ) {
// If the user specified an explicit method name, search for that.
// If the user specified an explicit method name, search for that.
overrideMethod = findTestBeanFactoryMethod ( testClass , field . getType ( ) , methodName ) ;
overrideMethod = findTestBeanFactoryMethod ( testClass , field . getType ( ) , methodName ) ;
}
}
else {
else {
// Otherwise, search for candidate factory methods the field name
// Otherwise, search for candidate factory methods whose names match either
// or explicit bean name (if any).
// the field name or the explicit bean name (if any).
List < String > candidateMethodNames = new ArrayList < > ( ) ;
List < String > candidateMethodNames = new ArrayList < > ( ) ;
candidateMethodNames . add ( field . getName ( ) ) ;
candidateMethodNames . add ( field . getName ( ) ) ;
String beanName = testBeanAnnotation . name ( ) ;
if ( beanName ! = null ) {
if ( StringUtils . hasText ( beanName ) ) {
candidateMethodNames . add ( beanName ) ;
candidateMethodNames . add ( beanName ) ;
}
}
overrideMethod = findTestBeanFactoryMethod ( testClass , field . getType ( ) , candidateMethodNames ) ;
overrideMethod = findTestBeanFactoryMethod ( testClass , field . getType ( ) , candidateMethodNames ) ;
}
}
String beanName = ( StringUtils . hasText ( testBeanAnnotation . name ( ) ) ? testBeanAnnotation . name ( ) : null ) ;
return new TestBeanOverrideMetadata ( field , ResolvableType . forField ( field , testClass ) , beanName , overrideMethod ) ;
return new TestBeanOverrideMetadata (
field , ResolvableType . forField ( field , testClass ) , beanName , strategy , overrideMethod ) ;
}
}
/ * *
/ * *