diff --git a/core/spring-boot/src/main/java/org/springframework/boot/SpringApplicationAotProcessor.java b/core/spring-boot/src/main/java/org/springframework/boot/SpringApplicationAotProcessor.java index 13bdaf91e8e..6a78e519030 100644 --- a/core/spring-boot/src/main/java/org/springframework/boot/SpringApplicationAotProcessor.java +++ b/core/spring-boot/src/main/java/org/springframework/boot/SpringApplicationAotProcessor.java @@ -58,12 +58,27 @@ public class SpringApplicationAotProcessor extends ContextAotProcessor { @Override protected GenericApplicationContext prepareApplicationContext(Class application) { return new AotProcessorHook(application).run(() -> { - Method mainMethod = application.getMethod("main", String[].class); - ReflectionUtils.invokeMethod(mainMethod, null, new Object[] { this.applicationArgs }); + Method mainMethod = getMainMethod(application); + mainMethod.setAccessible(true); + if (mainMethod.getParameterCount() == 0) { + ReflectionUtils.invokeMethod(mainMethod, null); + } + else { + ReflectionUtils.invokeMethod(mainMethod, null, new Object[] { this.applicationArgs }); + } return Void.class; }); } + private static Method getMainMethod(Class application) throws Exception { + try { + return application.getDeclaredMethod("main", String[].class); + } + catch (NoSuchMethodException ex) { + return application.getDeclaredMethod("main"); + } + } + public static void main(String[] args) throws Exception { int requiredArgs = 6; Assert.state(args.length >= requiredArgs, () -> "Usage: " + SpringApplicationAotProcessor.class.getName() diff --git a/core/spring-boot/src/test/java/org/springframework/boot/SpringApplicationAotProcessorTests.java b/core/spring-boot/src/test/java/org/springframework/boot/SpringApplicationAotProcessorTests.java index a852f34da67..70b38780595 100644 --- a/core/spring-boot/src/test/java/org/springframework/boot/SpringApplicationAotProcessorTests.java +++ b/core/spring-boot/src/test/java/org/springframework/boot/SpringApplicationAotProcessorTests.java @@ -37,20 +37,21 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException; */ class SpringApplicationAotProcessorTests { + private static final ApplicationInvoker invoker = new ApplicationInvoker(); + @BeforeEach void setup() { - SampleApplication.argsHolder = null; - SampleApplication.postRunInvoked = false; + invoker.clean(); } @Test - void processApplicationInvokesRunMethod(@TempDir Path directory) { + void processApplicationInvokesMainMethod(@TempDir Path directory) { String[] arguments = new String[] { "1", "2" }; - SpringApplicationAotProcessor processor = new SpringApplicationAotProcessor(SampleApplication.class, + SpringApplicationAotProcessor processor = new SpringApplicationAotProcessor(PublicMainMethod.class, settings(directory), arguments); processor.process(); - assertThat(SampleApplication.argsHolder).isEqualTo(arguments); - assertThat(SampleApplication.postRunInvoked).isFalse(); + assertThat(ApplicationInvoker.argsHolder).isEqualTo(arguments); + assertThat(ApplicationInvoker.postRunInvoked).isFalse(); } @Test @@ -63,23 +64,53 @@ class SpringApplicationAotProcessorTests { } @Test - void invokeMainParsesArgumentsAndInvokesRunMethod(@TempDir Path directory) throws Exception { - String[] mainArguments = new String[] { SampleApplication.class.getName(), + void invokeMainParsesArgumentsAndInvokesMainMethod(@TempDir Path directory) throws Exception { + String[] mainArguments = new String[] { PublicMainMethod.class.getName(), + directory.resolve("source").toString(), directory.resolve("resource").toString(), + directory.resolve("class").toString(), "com.example", "example", "1", "2" }; + SpringApplicationAotProcessor.main(mainArguments); + assertThat(ApplicationInvoker.argsHolder).containsExactly("1", "2"); + assertThat(ApplicationInvoker.postRunInvoked).isFalse(); + } + + @Test + void invokeMainParsesArgumentsAndInvokesPackagePrivateMainMethod(@TempDir Path directory) throws Exception { + String[] mainArguments = new String[] { PackagePrivateMainMethod.class.getName(), directory.resolve("source").toString(), directory.resolve("resource").toString(), directory.resolve("class").toString(), "com.example", "example", "1", "2" }; SpringApplicationAotProcessor.main(mainArguments); - assertThat(SampleApplication.argsHolder).containsExactly("1", "2"); - assertThat(SampleApplication.postRunInvoked).isFalse(); + assertThat(ApplicationInvoker.argsHolder).containsExactly("1", "2"); + assertThat(ApplicationInvoker.postRunInvoked).isFalse(); + } + + @Test + void invokeMainParsesArgumentsAndInvokesParameterLessMainMethod(@TempDir Path directory) throws Exception { + String[] mainArguments = new String[] { PublicParameterlessMainMethod.class.getName(), + directory.resolve("source").toString(), directory.resolve("resource").toString(), + directory.resolve("class").toString(), "com.example", "example", "1", "2" }; + SpringApplicationAotProcessor.main(mainArguments); + assertThat(ApplicationInvoker.argsHolder).isNull(); + assertThat(ApplicationInvoker.postRunInvoked).isFalse(); + } + + @Test + void invokeMainParsesArgumentsAndInvokesPackagePrivateRunMethod(@TempDir Path directory) throws Exception { + String[] mainArguments = new String[] { PackagePrivateParameterlessMainMethod.class.getName(), + directory.resolve("source").toString(), directory.resolve("resource").toString(), + directory.resolve("class").toString(), "com.example", "example", "1", "2" }; + SpringApplicationAotProcessor.main(mainArguments); + assertThat(ApplicationInvoker.argsHolder).isNull(); + assertThat(ApplicationInvoker.postRunInvoked).isFalse(); } @Test void invokeMainParsesArgumentsAndInvokesRunMethodWithoutGroupId(@TempDir Path directory) throws Exception { - String[] mainArguments = new String[] { SampleApplication.class.getName(), + String[] mainArguments = new String[] { PublicMainMethod.class.getName(), directory.resolve("source").toString(), directory.resolve("resource").toString(), directory.resolve("class").toString(), "", "example", "1", "2" }; SpringApplicationAotProcessor.main(mainArguments); - assertThat(SampleApplication.argsHolder).containsExactly("1", "2"); - assertThat(SampleApplication.postRunInvoked).isFalse(); + assertThat(ApplicationInvoker.argsHolder).containsExactly("1", "2"); + assertThat(ApplicationInvoker.postRunInvoked).isFalse(); } @Test @@ -100,16 +131,37 @@ class SpringApplicationAotProcessorTests { } @Configuration(proxyBeanMethods = false) - public static class SampleApplication { + public static class PublicMainMethod { - public static String @Nullable [] argsHolder; + public static void main(String[] args) { + invoker.invoke(args, () -> SpringApplication.run(PublicMainMethod.class, args)); + } - public static boolean postRunInvoked; + } - public static void main(String[] args) { - argsHolder = args; - SpringApplication.run(SampleApplication.class, args); - postRunInvoked = true; + @Configuration(proxyBeanMethods = false) + public static class PackagePrivateMainMethod { + + static void main(String[] args) { + invoker.invoke(args, () -> SpringApplication.run(PackagePrivateMainMethod.class, args)); + } + + } + + @Configuration(proxyBeanMethods = false) + public static class PublicParameterlessMainMethod { + + public static void main() { + invoker.invoke(null, () -> SpringApplication.run(PublicParameterlessMainMethod.class)); + } + + } + + @Configuration(proxyBeanMethods = false) + public static class PackagePrivateParameterlessMainMethod { + + static void main() { + invoker.invoke(null, () -> SpringApplication.run(PackagePrivateParameterlessMainMethod.class)); } } @@ -122,4 +174,23 @@ class SpringApplicationAotProcessorTests { } + private static final class ApplicationInvoker { + + public static String @Nullable [] argsHolder; + + public static boolean postRunInvoked; + + void invoke(String @Nullable [] args, Runnable applicationRun) { + argsHolder = args; + applicationRun.run(); + postRunInvoked = true; + } + + void clean() { + argsHolder = null; + postRunInvoked = false; + } + + } + }