diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java index e59fa6bf69a..09be887a52a 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java @@ -25,6 +25,7 @@ import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Properties; import java.util.Set; import java.util.stream.Collectors; @@ -274,18 +275,10 @@ public class SpringApplication { } private Class deduceMainApplicationClass() { - try { - StackTraceElement[] stackTrace = new RuntimeException().getStackTrace(); - for (StackTraceElement stackTraceElement : stackTrace) { - if ("main".equals(stackTraceElement.getMethodName())) { - return Class.forName(stackTraceElement.getClassName()); - } - } - } - catch (ClassNotFoundException ex) { - // Swallow and continue - } - return null; + return StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE) + .walk((s) -> s.filter(e -> Objects.equals(e.getMethodName(), "main")).findFirst() + .map(StackWalker.StackFrame::getDeclaringClass)) + .orElse(null); } /** diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java index 5326bddcb4b..022c4d3301c 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java @@ -23,6 +23,7 @@ import java.util.Iterator; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -1314,6 +1315,35 @@ class SpringApplicationTests { .accepts(hints); } + @Test + void deduceMainApplicationClass() { + assertThat( + Objects.equals(deduceMainApplicationClassByStackWalker(), deduceMainApplicationClassByThrowException())) + .isTrue(); + } + + private Class deduceMainApplicationClassByThrowException() { + try { + StackTraceElement[] stackTrace = new RuntimeException().getStackTrace(); + for (StackTraceElement stackTraceElement : stackTrace) { + if ("main".equals(stackTraceElement.getMethodName())) { + return Class.forName(stackTraceElement.getClassName()); + } + } + } + catch (ClassNotFoundException ex) { + // Swallow and continue + } + return null; + } + + private Class deduceMainApplicationClassByStackWalker() { + return StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE) + .walk((s) -> s.filter(e -> Objects.equals(e.getMethodName(), "main")).findFirst() + .map(StackWalker.StackFrame::getDeclaringClass)) + .orElse(null); + } + private ArgumentMatcher isAvailabilityChangeEventWithState( S state) { return (argument) -> (argument instanceof AvailabilityChangeEvent)