Use StackWalker to deduce main application class

See gh-31701
This commit is contained in:
GGGGGHT 2022-07-13 15:31:42 +08:00 committed by Andy Wilkinson
parent 19030f69c2
commit ea3fe95881
2 changed files with 35 additions and 12 deletions

View File

@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Properties; import java.util.Properties;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -274,18 +275,10 @@ public class SpringApplication {
} }
private Class<?> deduceMainApplicationClass() { private Class<?> deduceMainApplicationClass() {
try { return StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE)
StackTraceElement[] stackTrace = new RuntimeException().getStackTrace(); .walk((s) -> s.filter(e -> Objects.equals(e.getMethodName(), "main")).findFirst()
for (StackTraceElement stackTraceElement : stackTrace) { .map(StackWalker.StackFrame::getDeclaringClass))
if ("main".equals(stackTraceElement.getMethodName())) { .orElse(null);
return Class.forName(stackTraceElement.getClassName());
}
}
}
catch (ClassNotFoundException ex) {
// Swallow and continue
}
return null;
} }
/** /**

View File

@ -23,6 +23,7 @@ import java.util.Iterator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -1314,6 +1315,35 @@ class SpringApplicationTests {
.accepts(hints); .accepts(hints);
} }
@Test
void deduceMainApplicationClass() {
assertThat(
Objects.equals(deduceMainApplicationClassByStackWalker(), deduceMainApplicationClassByThrowException()))
.isTrue();
}
private Class<?> deduceMainApplicationClassByThrowException() {
try {
StackTraceElement[] stackTrace = new RuntimeException().getStackTrace();
for (StackTraceElement stackTraceElement : stackTrace) {
if ("main".equals(stackTraceElement.getMethodName())) {
return Class.forName(stackTraceElement.getClassName());
}
}
}
catch (ClassNotFoundException ex) {
// Swallow and continue
}
return null;
}
private Class<?> deduceMainApplicationClassByStackWalker() {
return StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE)
.walk((s) -> s.filter(e -> Objects.equals(e.getMethodName(), "main")).findFirst()
.map(StackWalker.StackFrame::getDeclaringClass))
.orElse(null);
}
private <S extends AvailabilityState> ArgumentMatcher<ApplicationEvent> isAvailabilityChangeEventWithState( private <S extends AvailabilityState> ArgumentMatcher<ApplicationEvent> isAvailabilityChangeEventWithState(
S state) { S state) {
return (argument) -> (argument instanceof AvailabilityChangeEvent<?>) return (argument) -> (argument instanceof AvailabilityChangeEvent<?>)