@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2012 - 2024 the original author or authors .
* Copyright 2012 - 2025 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -19,11 +19,19 @@ package org.springframework.boot.test.context;
@@ -19,11 +19,19 @@ package org.springframework.boot.test.context;
import java.lang.reflect.Method ;
import java.util.ArrayList ;
import java.util.Arrays ;
import java.util.Collection ;
import java.util.Collections ;
import java.util.List ;
import java.util.function.Consumer ;
import org.springframework.aot.generate.GenerationContext ;
import org.springframework.aot.hint.ExecutableMode ;
import org.springframework.aot.hint.ReflectionHints ;
import org.springframework.beans.BeanUtils ;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution ;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor ;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode ;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory ;
import org.springframework.boot.ApplicationContextFactory ;
import org.springframework.boot.Banner ;
import org.springframework.boot.ConfigurableBootstrapContext ;
@ -158,20 +166,23 @@ public class SpringBootContextLoader extends AbstractContextLoader implements Ao
@@ -158,20 +166,23 @@ public class SpringBootContextLoader extends AbstractContextLoader implements Ao
. orElse ( null ) ;
Assert . state ( springBootConfiguration ! = null | | useMainMethod = = UseMainMethod . WHEN_AVAILABLE ,
"Cannot use main method as no @SpringBootConfiguration-annotated class is available" ) ;
Method mainMethod = ( springBootConfiguration ! = null )
? ReflectionUtils . findMethod ( springBootConfiguration , "main" , String [ ] . class ) : null ;
Method mainMethod = findMainMethod ( springBootConfiguration ) ;
Assert . state ( mainMethod ! = null | | useMainMethod = = UseMainMethod . WHEN_AVAILABLE ,
( ) - > "Main method not found on '%s'" . formatted ( springBootConfiguration . getName ( ) ) ) ;
return mainMethod ;
}
private static Method findMainMethod ( Class < ? > type ) {
Method mainMethod = ( type ! = null ) ? ReflectionUtils . findMethod ( type , "main" , String [ ] . class ) : null ;
if ( mainMethod = = null & & KotlinDetector . isKotlinPresent ( ) ) {
try {
Class < ? > kotlinClass = ClassUtils . forName ( springBootConfiguration . getName ( ) + "Kt" ,
springBootConfiguration . getClassLoader ( ) ) ;
Class < ? > kotlinClass = ClassUtils . forName ( type . getName ( ) + "Kt" , type . getClassLoader ( ) ) ;
mainMethod = ReflectionUtils . findMethod ( kotlinClass , "main" , String [ ] . class ) ;
}
catch ( ClassNotFoundException ex ) {
// Ignore
}
}
Assert . state ( mainMethod ! = null | | useMainMethod = = UseMainMethod . WHEN_AVAILABLE ,
( ) - > "Main method not found on '%s'" . formatted ( springBootConfiguration . getName ( ) ) ) ;
return mainMethod ;
}
@ -574,4 +585,39 @@ public class SpringBootContextLoader extends AbstractContextLoader implements Ao
@@ -574,4 +585,39 @@ public class SpringBootContextLoader extends AbstractContextLoader implements Ao
}
static class MainMethodBeanFactoryInitializationAotProcessor implements BeanFactoryInitializationAotProcessor {
@Override
public BeanFactoryInitializationAotContribution processAheadOfTime (
ConfigurableListableBeanFactory beanFactory ) {
List < Method > mainMethods = new ArrayList < > ( ) ;
for ( String beanName : beanFactory . getBeanDefinitionNames ( ) ) {
Class < ? > beanType = beanFactory . getType ( beanName ) ;
Method mainMethod = findMainMethod ( beanType ) ;
if ( mainMethod ! = null ) {
mainMethods . add ( mainMethod ) ;
}
}
return ! mainMethods . isEmpty ( ) ? new AotContribution ( mainMethods ) : null ;
}
static class AotContribution implements BeanFactoryInitializationAotContribution {
private final Collection < Method > mainMethods ;
AotContribution ( Collection < Method > mainMethods ) {
this . mainMethods = mainMethods ;
}
@Override
public void applyTo ( GenerationContext generationContext ,
BeanFactoryInitializationCode beanFactoryInitializationCode ) {
ReflectionHints reflectionHints = generationContext . getRuntimeHints ( ) . reflection ( ) ;
this . mainMethods . forEach ( ( method ) - > reflectionHints . registerMethod ( method , ExecutableMode . INVOKE ) ) ;
}
}
}
}