diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/context/SpringBootContextLoader.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/context/SpringBootContextLoader.java index c58f1bc4da2..6fe35dc2b9d 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/context/SpringBootContextLoader.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/context/SpringBootContextLoader.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2024 the original author or authors. + * Copyright 2012-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,11 +19,19 @@ package org.springframework.boot.test.context; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.function.Consumer; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.hint.ExecutableMode; +import org.springframework.aot.hint.ReflectionHints; import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor; +import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.boot.ApplicationContextFactory; import org.springframework.boot.Banner; import org.springframework.boot.ConfigurableBootstrapContext; @@ -158,20 +166,23 @@ public class SpringBootContextLoader extends AbstractContextLoader implements Ao .orElse(null); Assert.state(springBootConfiguration != null || useMainMethod == UseMainMethod.WHEN_AVAILABLE, "Cannot use main method as no @SpringBootConfiguration-annotated class is available"); - Method mainMethod = (springBootConfiguration != null) - ? ReflectionUtils.findMethod(springBootConfiguration, "main", String[].class) : null; + Method mainMethod = findMainMethod(springBootConfiguration); + Assert.state(mainMethod != null || useMainMethod == UseMainMethod.WHEN_AVAILABLE, + () -> "Main method not found on '%s'".formatted(springBootConfiguration.getName())); + return mainMethod; + } + + private static Method findMainMethod(Class type) { + Method mainMethod = (type != null) ? ReflectionUtils.findMethod(type, "main", String[].class) : null; if (mainMethod == null && KotlinDetector.isKotlinPresent()) { try { - Class kotlinClass = ClassUtils.forName(springBootConfiguration.getName() + "Kt", - springBootConfiguration.getClassLoader()); + Class kotlinClass = ClassUtils.forName(type.getName() + "Kt", type.getClassLoader()); mainMethod = ReflectionUtils.findMethod(kotlinClass, "main", String[].class); } catch (ClassNotFoundException ex) { // Ignore } } - Assert.state(mainMethod != null || useMainMethod == UseMainMethod.WHEN_AVAILABLE, - () -> "Main method not found on '%s'".formatted(springBootConfiguration.getName())); return mainMethod; } @@ -574,4 +585,39 @@ public class SpringBootContextLoader extends AbstractContextLoader implements Ao } + static class MainMethodBeanFactoryInitializationAotProcessor implements BeanFactoryInitializationAotProcessor { + + @Override + public BeanFactoryInitializationAotContribution processAheadOfTime( + ConfigurableListableBeanFactory beanFactory) { + List mainMethods = new ArrayList<>(); + for (String beanName : beanFactory.getBeanDefinitionNames()) { + Class beanType = beanFactory.getType(beanName); + Method mainMethod = findMainMethod(beanType); + if (mainMethod != null) { + mainMethods.add(mainMethod); + } + } + return !mainMethods.isEmpty() ? new AotContribution(mainMethods) : null; + } + + static class AotContribution implements BeanFactoryInitializationAotContribution { + + private final Collection mainMethods; + + AotContribution(Collection mainMethods) { + this.mainMethods = mainMethods; + } + + @Override + public void applyTo(GenerationContext generationContext, + BeanFactoryInitializationCode beanFactoryInitializationCode) { + ReflectionHints reflectionHints = generationContext.getRuntimeHints().reflection(); + this.mainMethods.forEach((method) -> reflectionHints.registerMethod(method, ExecutableMode.INVOKE)); + } + + } + + } + } diff --git a/spring-boot-project/spring-boot-test/src/main/resources/META-INF/spring/aot.factories b/spring-boot-project/spring-boot-test/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..ba72dcee9a7 --- /dev/null +++ b/spring-boot-project/spring-boot-test/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor=\ +org.springframework.boot.test.context.SpringBootContextLoader.MainMethodBeanFactoryInitializationAotProcessor diff --git a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/context/SpringBootContextLoaderTests.java b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/context/SpringBootContextLoaderTests.java index 3f6a4888d11..5d5b9920a7c 100644 --- a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/context/SpringBootContextLoaderTests.java +++ b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/context/SpringBootContextLoaderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2024 the original author or authors. + * Copyright 2012-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,10 +25,16 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; +import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.beans.factory.BeanCreationException; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.boot.ApplicationContextFactory; import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootContextLoader.MainMethodBeanFactoryInitializationAotProcessor; import org.springframework.boot.test.context.SpringBootTest.UseMainMethod; import org.springframework.boot.test.util.TestPropertyValues; import org.springframework.boot.web.reactive.context.GenericReactiveWebApplicationContext; @@ -248,6 +254,35 @@ class SpringBootContextLoaderTests { .withMessage("UseMainMethod.ALWAYS cannot be used with @ContextHierarchy tests"); } + @Test + void whenMainMethodPresentRegisterReflectionHints() { + TestContext testContext = new ExposedTestContextManager(UseMainMethodWhenAvailableAndNoMainMethod.class) + .getExposedTestContext(); + ConfigurableListableBeanFactory beanFactory = (ConfigurableListableBeanFactory) testContext + .getApplicationContext() + .getAutowireCapableBeanFactory(); + BeanFactoryInitializationAotContribution aotContribution = new MainMethodBeanFactoryInitializationAotProcessor() + .processAheadOfTime(beanFactory); + assertThat(aotContribution).isNull(); + } + + @Test + void whenMainMethodNotAvailableReturnsNoAotContribution() { + TestContext testContext = new ExposedTestContextManager(UseMainMethodWhenAvailableAndMainMethod.class) + .getExposedTestContext(); + ConfigurableListableBeanFactory beanFactory = (ConfigurableListableBeanFactory) testContext + .getApplicationContext() + .getAutowireCapableBeanFactory(); + BeanFactoryInitializationAotContribution aotContribution = new MainMethodBeanFactoryInitializationAotProcessor() + .processAheadOfTime(beanFactory); + assertThat(aotContribution).isNotNull(); + TestGenerationContext generationContext = new TestGenerationContext(); + aotContribution.applyTo(generationContext, null); + RuntimeHints runtimeHints = generationContext.getRuntimeHints(); + assertThat(RuntimeHintsPredicates.reflection().onMethod(ConfigWithMain.class, "main").invoke()) + .accepts(runtimeHints); + } + @Test void whenSubclassProvidesCustomApplicationContextFactory() { TestContext testContext = new ExposedTestContextManager(CustomApplicationContextTest.class)