Make sure RuntimeHintsRegistrar are invoked only once

Close gh-28594
This commit is contained in:
Stephane Nicoll 2022-06-09 15:02:32 +02:00
parent 74d1be9bd8
commit 363722893b
4 changed files with 96 additions and 28 deletions

View File

@ -25,17 +25,43 @@ import java.lang.annotation.Target;
import org.springframework.aot.hint.RuntimeHintsRegistrar; import org.springframework.aot.hint.RuntimeHintsRegistrar;
/** /**
* Indicates that one or more {@link RuntimeHintsRegistrar} implementations should be processed. * Indicates that one or more {@link RuntimeHintsRegistrar} implementations
* <p>Unlike declaring {@link RuntimeHintsRegistrar} as {@code spring/aot.factories}, * should be processed.
* {@code @ImportRuntimeHints} allows for more flexible use cases where registrations are only *
* processed if the annotated configuration class or bean method is considered by the * <p>Unlike declaring {@link RuntimeHintsRegistrar} using
* application context. * {@code spring/aot.factories}, this annotation allows for more flexible
* registration where it is only processed if the annotated component or bean
* method is actually registered in the bean factory. To illustrate this
* behavior, consider the following example:
*
* <pre class="code">
* &#064;Configuration
* public class MyConfiguration {
*
* &#064;Bean
* &#064;ImportRuntimeHints(MyHints.class)
* &#064;Conditional(MyCondition.class)
* public MyService myService() {
* return new MyService();
* }
*
* }</pre>
*
* If the configuration class above is processed, {@code MyHints} will be
* contributed only if {@code MyCondition} matches. If it does not, and
* therefore {@code MyService} is not defined as a bean, the hints will
* not be processed either.
*
* <p>If several components refer to the same {@link RuntimeHintsRegistrar}
* implementation, it is invoked only once for a given bean factory
* processing.
* *
* @author Brian Clozel * @author Brian Clozel
* @author Stephane Nicoll
* @since 6.0 * @since 6.0
* @see org.springframework.aot.hint.RuntimeHints * @see org.springframework.aot.hint.RuntimeHints
*/ */
@Target({ElementType.TYPE, ElementType.METHOD}) @Target({ ElementType.TYPE, ElementType.METHOD })
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Documented @Documented
public @interface ImportRuntimeHints { public @interface ImportRuntimeHints {

View File

@ -16,8 +16,10 @@
package org.springframework.context.aot; package org.springframework.context.aot;
import java.util.ArrayList; import java.util.LinkedHashMap;
import java.util.List; import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -55,29 +57,37 @@ class RuntimeHintsBeanFactoryInitializationAotProcessor
public BeanFactoryInitializationAotContribution processAheadOfTime( public BeanFactoryInitializationAotContribution processAheadOfTime(
ConfigurableListableBeanFactory beanFactory) { ConfigurableListableBeanFactory beanFactory) {
AotFactoriesLoader loader = new AotFactoriesLoader(beanFactory); AotFactoriesLoader loader = new AotFactoriesLoader(beanFactory);
List<RuntimeHintsRegistrar> registrars = new ArrayList<>( Map<Class<? extends RuntimeHintsRegistrar>, RuntimeHintsRegistrar> registrars = loader
loader.load(RuntimeHintsRegistrar.class)); .load(RuntimeHintsRegistrar.class).stream()
.collect(LinkedHashMap::new, (map, item) -> map.put(item.getClass(), item), Map::putAll);
extractFromBeanFactory(beanFactory).forEach(registrarClass ->
registrars.computeIfAbsent(registrarClass, BeanUtils::instantiateClass));
return new RuntimeHintsRegistrarContribution(registrars.values(),
beanFactory.getBeanClassLoader());
}
private Set<Class<? extends RuntimeHintsRegistrar>> extractFromBeanFactory(ConfigurableListableBeanFactory beanFactory) {
Set<Class<? extends RuntimeHintsRegistrar>> registrarClasses = new LinkedHashSet<>();
for (String beanName : beanFactory for (String beanName : beanFactory
.getBeanNamesForAnnotation(ImportRuntimeHints.class)) { .getBeanNamesForAnnotation(ImportRuntimeHints.class)) {
ImportRuntimeHints annotation = beanFactory.findAnnotationOnBean(beanName, ImportRuntimeHints annotation = beanFactory.findAnnotationOnBean(beanName,
ImportRuntimeHints.class); ImportRuntimeHints.class);
if (annotation != null) { if (annotation != null) {
registrars.addAll(extracted(beanName, annotation)); registrarClasses.addAll(extractFromBeanDefinition(beanName, annotation));
} }
} }
return new RuntimeHintsRegistrarContribution(registrars, return registrarClasses;
beanFactory.getBeanClassLoader());
} }
private List<RuntimeHintsRegistrar> extracted(String beanName, private Set<Class<? extends RuntimeHintsRegistrar>> extractFromBeanDefinition(String beanName,
ImportRuntimeHints annotation) { ImportRuntimeHints annotation) {
Class<? extends RuntimeHintsRegistrar>[] registrarClasses = annotation.value();
List<RuntimeHintsRegistrar> registrars = new ArrayList<>(registrarClasses.length); Set<Class<? extends RuntimeHintsRegistrar>> registrars = new LinkedHashSet<>();
for (Class<? extends RuntimeHintsRegistrar> registrarClass : registrarClasses) { for (Class<? extends RuntimeHintsRegistrar> registrarClass : annotation.value()) {
logger.trace( logger.trace(
LogMessage.format("Loaded [%s] registrar from annotated bean [%s]", LogMessage.format("Loaded [%s] registrar from annotated bean [%s]",
registrarClass.getCanonicalName(), beanName)); registrarClass.getCanonicalName(), beanName));
registrars.add(BeanUtils.instantiateClass(registrarClass)); registrars.add(registrarClass);
} }
return registrars; return registrars;
} }
@ -87,13 +97,13 @@ class RuntimeHintsBeanFactoryInitializationAotProcessor
implements BeanFactoryInitializationAotContribution { implements BeanFactoryInitializationAotContribution {
private final List<RuntimeHintsRegistrar> registrars; private final Iterable<RuntimeHintsRegistrar> registrars;
@Nullable @Nullable
private final ClassLoader beanClassLoader; private final ClassLoader beanClassLoader;
RuntimeHintsRegistrarContribution(List<RuntimeHintsRegistrar> registrars, RuntimeHintsRegistrarContribution(Iterable<RuntimeHintsRegistrar> registrars,
@Nullable ClassLoader beanClassLoader) { @Nullable ClassLoader beanClassLoader) {
this.registrars = registrars; this.registrars = registrars;
this.beanClassLoader = beanClassLoader; this.beanClassLoader = beanClassLoader;

View File

@ -19,6 +19,7 @@ package org.springframework.context.aot;
import java.io.IOException; import java.io.IOException;
import java.net.URL; import java.net.URL;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -38,6 +39,7 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.ImportRuntimeHints; import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.context.support.GenericApplicationContext; import org.springframework.context.support.GenericApplicationContext;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.lang.Nullable;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -91,13 +93,31 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
assertThatSampleRegistrarContributed(); assertThatSampleRegistrarContributed();
} }
@Test
void shouldProcessDuplicatedRegistrarsOnlyOnce() {
GenericApplicationContext applicationContext = createApplicationContext();
applicationContext.registerBeanDefinition("incremental1",
new RootBeanDefinition(ConfigurationWithIncrementalHints.class));
applicationContext.registerBeanDefinition("incremental2",
new RootBeanDefinition(ConfigurationWithIncrementalHints.class));
applicationContext.setClassLoader(
new TestSpringFactoriesClassLoader("test-duplicated-runtime-hints-aot.factories"));
IncrementalRuntimeHintsRegistrar.counter.set(0);
this.generator.generateApplicationContext(applicationContext,
this.generationContext, MAIN_GENERATED_TYPE);
RuntimeHints runtimeHints = this.generationContext.getRuntimeHints();
assertThat(runtimeHints.resources().resourceBundles().map(ResourceBundleHint::getBaseName))
.containsOnly("com.example.example0", "sample");
assertThat(IncrementalRuntimeHintsRegistrar.counter.get()).isEqualTo(1);
}
@Test @Test
void shouldRejectRuntimeHintsRegistrarWithoutDefaultConstructor() { void shouldRejectRuntimeHintsRegistrarWithoutDefaultConstructor() {
GenericApplicationContext applicationContext = createApplicationContext( GenericApplicationContext applicationContext = createApplicationContext(
ConfigurationWithIllegalRegistrar.class); ConfigurationWithIllegalRegistrar.class);
assertThatThrownBy(() -> this.generator.generateApplicationContext( assertThatThrownBy(() -> this.generator.generateApplicationContext(
applicationContext, this.generationContext, MAIN_GENERATED_TYPE)) applicationContext, this.generationContext, MAIN_GENERATED_TYPE))
.isInstanceOf(BeanInstantiationException.class); .isInstanceOf(BeanInstantiationException.class);
} }
private void assertThatSampleRegistrarContributed() { private void assertThatSampleRegistrarContributed() {
@ -119,10 +139,9 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
} }
@ImportRuntimeHints(SampleRuntimeHintsRegistrar.class)
@Configuration(proxyBeanMethods = false) @Configuration(proxyBeanMethods = false)
@ImportRuntimeHints(SampleRuntimeHintsRegistrar.class)
static class ConfigurationWithHints { static class ConfigurationWithHints {
} }
@ -137,7 +156,6 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
} }
public static class SampleRuntimeHintsRegistrar implements RuntimeHintsRegistrar { public static class SampleRuntimeHintsRegistrar implements RuntimeHintsRegistrar {
@Override @Override
@ -147,19 +165,31 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
} }
@Configuration(proxyBeanMethods = false)
@ImportRuntimeHints(IncrementalRuntimeHintsRegistrar.class)
static class ConfigurationWithIncrementalHints {
}
static class IncrementalRuntimeHintsRegistrar implements RuntimeHintsRegistrar {
static final AtomicInteger counter = new AtomicInteger();
@Override
public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) {
hints.resources().registerResourceBundle("com.example.example" + counter.getAndIncrement());
}
}
static class SampleBean { static class SampleBean {
} }
@ImportRuntimeHints(IllegalRuntimeHintsRegistrar.class)
@Configuration(proxyBeanMethods = false) @Configuration(proxyBeanMethods = false)
@ImportRuntimeHints(IllegalRuntimeHintsRegistrar.class)
static class ConfigurationWithIllegalRegistrar { static class ConfigurationWithIllegalRegistrar {
} }
public static class IllegalRuntimeHintsRegistrar implements RuntimeHintsRegistrar { public static class IllegalRuntimeHintsRegistrar implements RuntimeHintsRegistrar {
public IllegalRuntimeHintsRegistrar(String arg) { public IllegalRuntimeHintsRegistrar(String arg) {
@ -173,7 +203,6 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
} }
static class TestSpringFactoriesClassLoader extends ClassLoader { static class TestSpringFactoriesClassLoader extends ClassLoader {
private final String factoriesName; private final String factoriesName;

View File

@ -0,0 +1,3 @@
org.springframework.aot.hint.RuntimeHintsRegistrar= \
org.springframework.context.aot.RuntimeHintsBeanFactoryInitializationAotProcessorTests.IncrementalRuntimeHintsRegistrar, \
org.springframework.context.aot.RuntimeHintsBeanFactoryInitializationAotProcessorTests.SampleRuntimeHintsRegistrar