Add support for ImportAware callback

This commit adds a way for a BeanFactoryPostProcessor to participate to
AOT optimizations by contributing code that replaces its runtime
behaviour.

ConfigurationClassPostProcessor does implement this new interface and
computes a mapping of the ImportAware configuration classes. The mapping
is generated for latter reuse by ImportAwareAotBeanPostProcessor.

Closes gh-2811
This commit is contained in:
Stephane Nicoll 2022-03-06 17:44:42 +01:00
parent ec6a19fc6b
commit 9ba927215e
7 changed files with 467 additions and 2 deletions

View File

@ -0,0 +1,49 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.beans.factory.generator;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.lang.Nullable;
/**
* Specialization of {@link BeanFactoryPostProcessor} that contributes bean
* factory optimizations ahead of time, using generated code that replaces
* runtime behavior.
*
* @author Stephane Nicoll
* @since 6.0
*/
@FunctionalInterface
public interface AotContributingBeanFactoryPostProcessor extends BeanFactoryPostProcessor {
/**
* Contribute a {@link BeanFactoryContribution} for the given bean factory,
* if applicable.
* @param beanFactory the bean factory to optimize
* @return the contribution to use or {@code null}
*/
@Nullable
BeanFactoryContribution contribute(ConfigurableListableBeanFactory beanFactory);
@Override
default void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ package org.springframework.context.annotation;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
@ -25,10 +26,14 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import javax.lang.model.element.Modifier;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.aop.framework.autoproxy.AutoProxyUtils; import org.springframework.aop.framework.autoproxy.AutoProxyUtils;
import org.springframework.aot.hint.ResourceHints;
import org.springframework.aot.hint.TypeReference;
import org.springframework.beans.PropertyValues; import org.springframework.beans.PropertyValues;
import org.springframework.beans.factory.BeanClassLoaderAware; import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.beans.factory.BeanDefinitionStoreException; import org.springframework.beans.factory.BeanDefinitionStoreException;
@ -40,6 +45,9 @@ import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessor; import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessor;
import org.springframework.beans.factory.config.SingletonBeanRegistry; import org.springframework.beans.factory.config.SingletonBeanRegistry;
import org.springframework.beans.factory.generator.AotContributingBeanFactoryPostProcessor;
import org.springframework.beans.factory.generator.BeanFactoryContribution;
import org.springframework.beans.factory.generator.BeanFactoryInitialization;
import org.springframework.beans.factory.parsing.FailFastProblemReporter; import org.springframework.beans.factory.parsing.FailFastProblemReporter;
import org.springframework.beans.factory.parsing.PassThroughSourceExtractor; import org.springframework.beans.factory.parsing.PassThroughSourceExtractor;
import org.springframework.beans.factory.parsing.ProblemReporter; import org.springframework.beans.factory.parsing.ProblemReporter;
@ -65,6 +73,10 @@ import org.springframework.core.type.AnnotationMetadata;
import org.springframework.core.type.MethodMetadata; import org.springframework.core.type.MethodMetadata;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory; import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReaderFactory; import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.ParameterizedTypeName;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
@ -89,7 +101,8 @@ import org.springframework.util.ClassUtils;
* @since 3.0 * @since 3.0
*/ */
public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPostProcessor, public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPostProcessor,
PriorityOrdered, ResourceLoaderAware, ApplicationStartupAware, BeanClassLoaderAware, EnvironmentAware { AotContributingBeanFactoryPostProcessor, PriorityOrdered, ResourceLoaderAware, ApplicationStartupAware,
BeanClassLoaderAware, EnvironmentAware {
/** /**
* A {@code BeanNameGenerator} using fully qualified class names as default bean names. * A {@code BeanNameGenerator} using fully qualified class names as default bean names.
@ -269,6 +282,12 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
beanFactory.addBeanPostProcessor(new ImportAwareBeanPostProcessor(beanFactory)); beanFactory.addBeanPostProcessor(new ImportAwareBeanPostProcessor(beanFactory));
} }
@Override
public BeanFactoryContribution contribute(ConfigurableListableBeanFactory beanFactory) {
return (beanFactory.containsBean(IMPORT_REGISTRY_BEAN_NAME)
? new ImportAwareBeanFactoryConfiguration(beanFactory) : null);
}
/** /**
* Build and validate a configuration model based on the registry of * Build and validate a configuration model based on the registry of
* {@link Configuration} classes. * {@link Configuration} classes.
@ -485,4 +504,55 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
} }
} }
private static final class ImportAwareBeanFactoryConfiguration implements BeanFactoryContribution {
private final ConfigurableListableBeanFactory beanFactory;
private ImportAwareBeanFactoryConfiguration(ConfigurableListableBeanFactory beanFactory) {
this.beanFactory = beanFactory;
}
@Override
public void applyTo(BeanFactoryInitialization initialization) {
Map<String, String> mappings = buildImportAwareMappings();
if (!mappings.isEmpty()) {
MethodSpec method = initialization.generatedTypeContext().getMainGeneratedType()
.addMethod(beanPostProcessorMethod(mappings));
initialization.contribute(code -> code.addStatement("beanFactory.addBeanPostProcessor($N())", method));
ResourceHints resourceHints = initialization.generatedTypeContext().runtimeHints().resources();
mappings.forEach((target, importedFrom) -> resourceHints.registerType(
TypeReference.of(importedFrom)));
}
}
private MethodSpec.Builder beanPostProcessorMethod(Map<String, String> mappings) {
Builder code = CodeBlock.builder();
code.addStatement("$T mappings = new $T<>()", ParameterizedTypeName.get(
Map.class, String.class, String.class), HashMap.class);
mappings.forEach((key, value) -> code.addStatement("mappings.put($S, $S)", key, value));
code.addStatement("return new $T($L)", ImportAwareAotBeanPostProcessor.class, "mappings");
return MethodSpec.methodBuilder("createImportAwareBeanPostProcessor")
.returns(ImportAwareAotBeanPostProcessor.class)
.addModifiers(Modifier.PRIVATE).addCode(code.build());
}
private Map<String, String> buildImportAwareMappings() {
ImportRegistry ir = this.beanFactory.getBean(IMPORT_REGISTRY_BEAN_NAME, ImportRegistry.class);
Map<String, String> mappings = new LinkedHashMap<>();
for (String name : this.beanFactory.getBeanDefinitionNames()) {
Class<?> beanType = this.beanFactory.getType(name);
if (beanType != null && ImportAware.class.isAssignableFrom(beanType)) {
String type = ClassUtils.getUserClass(beanType).getName();
AnnotationMetadata importingClassMetadata = ir.getImportingClassFor(type);
if (importingClassMetadata != null) {
mappings.put(type, importingClassMetadata.getClassName());
}
}
}
return mappings;
}
}
} }

View File

@ -0,0 +1,75 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.context.annotation;
import java.io.IOException;
import java.util.Map;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
/**
* A {@link BeanPostProcessor} that honours {@link ImportAware} callback using
* a mapping computed at build time.
*
* @author Stephane Nicoll
* @since 6.0
*/
public final class ImportAwareAotBeanPostProcessor implements BeanPostProcessor {
private final MetadataReaderFactory metadataReaderFactory;
private final Map<String, String> importsMapping;
public ImportAwareAotBeanPostProcessor(Map<String, String> importsMapping) {
this.metadataReaderFactory = new CachingMetadataReaderFactory();
this.importsMapping = Map.copyOf(importsMapping);
}
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) {
if (bean instanceof ImportAware) {
setAnnotationMetadata((ImportAware) bean);
}
return bean;
}
private void setAnnotationMetadata(ImportAware instance) {
String importingClass = getImportingClassFor(instance);
if (importingClass == null) {
return; // import aware configuration class not imported
}
try {
MetadataReader metadataReader = this.metadataReaderFactory.getMetadataReader(importingClass);
instance.setImportMetadata(metadataReader.getAnnotationMetadata());
}
catch (IOException ex) {
throw new IllegalStateException(String.format("Failed to read metadata for '%s'", importingClass), ex);
}
}
@Nullable
private String getImportingClassFor(ImportAware instance) {
String target = ClassUtils.getUserClass(instance).getName();
return this.importsMapping.get(target);
}
}

View File

@ -0,0 +1,90 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.context.annotation;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.springframework.core.type.AnnotationMetadata;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
/**
* Tests for {@link ImportAwareAotBeanPostProcessor}.
*
* @author Stephane Nicoll
*/
class ImportAwareAotBeanPostProcessorTests {
@Test
void postProcessOnMatchingCandidate() {
ImportAwareAotBeanPostProcessor postProcessor = new ImportAwareAotBeanPostProcessor(
Map.of(TestImportAware.class.getName(), ImportAwareAotBeanPostProcessorTests.class.getName()));
TestImportAware importAware = new TestImportAware();
postProcessor.postProcessBeforeInitialization(importAware, "test");
assertThat(importAware.importMetadata).isNotNull();
assertThat(importAware.importMetadata.getClassName())
.isEqualTo(ImportAwareAotBeanPostProcessorTests.class.getName());
}
@Test
void postProcessOnMatchingCandidateWithNestedClass() {
ImportAwareAotBeanPostProcessor postProcessor = new ImportAwareAotBeanPostProcessor(
Map.of(TestImportAware.class.getName(), TestImporting.class.getName()));
TestImportAware importAware = new TestImportAware();
postProcessor.postProcessBeforeInitialization(importAware, "test");
assertThat(importAware.importMetadata).isNotNull();
assertThat(importAware.importMetadata.getClassName())
.isEqualTo(TestImporting.class.getName());
}
@Test
void postProcessOnNoCandidateDoesNotInvokeCallback() {
ImportAwareAotBeanPostProcessor postProcessor = new ImportAwareAotBeanPostProcessor(
Map.of(String.class.getName(), ImportAwareAotBeanPostProcessorTests.class.getName()));
TestImportAware importAware = new TestImportAware();
postProcessor.postProcessBeforeInitialization(importAware, "test");
assertThat(importAware.importMetadata).isNull();
}
@Test
void postProcessOnMatchingCandidateWithNoMetadata() {
ImportAwareAotBeanPostProcessor postProcessor = new ImportAwareAotBeanPostProcessor(
Map.of(TestImportAware.class.getName(), "com.example.invalid.DoesNotExist"));
TestImportAware importAware = new TestImportAware();
assertThatIllegalStateException().isThrownBy(() -> postProcessor.postProcessBeforeInitialization(importAware, "test"))
.withMessageContaining("Failed to read metadata for 'com.example.invalid.DoesNotExist'");
}
static class TestImportAware implements ImportAware {
private AnnotationMetadata importMetadata;
@Override
public void setImportMetadata(AnnotationMetadata importMetadata) {
this.importMetadata = importMetadata;
}
}
static class TestImporting {
}
}

View File

@ -0,0 +1,115 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.context.annotation;
import java.io.IOException;
import java.io.StringWriter;
import org.junit.jupiter.api.Test;
import org.springframework.aot.generator.DefaultGeneratedTypeContext;
import org.springframework.aot.generator.GeneratedType;
import org.springframework.aot.generator.GeneratedTypeContext;
import org.springframework.beans.factory.generator.BeanFactoryContribution;
import org.springframework.beans.factory.generator.BeanFactoryInitialization;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration;
import org.springframework.context.testfixture.context.generator.annotation.ImportConfiguration;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.support.CodeSnippet;
import org.springframework.lang.Nullable;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@code ImportAwareBeanFactoryConfiguration}.
*
* @author Stephane Nicoll
*/
public class ImportAwareBeanFactoryContributionTests {
@Test
void contributeWithImportAwareConfigurationRegistersBeanPostProcessor() {
BeanFactoryContribution contribution = createContribution(ImportConfiguration.class);
assertThat(contribution).isNotNull();
BeanFactoryInitialization initialization = new BeanFactoryInitialization(createGenerationContext());
contribution.applyTo(initialization);
assertThat(CodeSnippet.of(initialization.toCodeBlock()).getSnippet()).isEqualTo("""
beanFactory.addBeanPostProcessor(createImportAwareBeanPostProcessor());
""");
}
@Test
void contributeWithImportAwareConfigurationCreateMappingsMethod() {
BeanFactoryContribution contribution = createContribution(ImportConfiguration.class);
assertThat(contribution).isNotNull();
GeneratedTypeContext generationContext = createGenerationContext();
contribution.applyTo(new BeanFactoryInitialization(generationContext));
assertThat(codeOf(generationContext.getMainGeneratedType())).contains("""
private ImportAwareAotBeanPostProcessor createImportAwareBeanPostProcessor() {
Map<String, String> mappings = new HashMap<>();
mappings.put("org.springframework.context.testfixture.context.generator.annotation.ImportAwareConfiguration", "org.springframework.context.testfixture.context.generator.annotation.ImportConfiguration");
return new ImportAwareAotBeanPostProcessor(mappings);
}
""");
}
@Test
void contributeWithImportAwareConfigurationRegisterBytecodeResourceHint() {
BeanFactoryContribution contribution = createContribution(ImportConfiguration.class);
assertThat(contribution).isNotNull();
GeneratedTypeContext generationContext = createGenerationContext();
contribution.applyTo(new BeanFactoryInitialization(generationContext));
assertThat(generationContext.runtimeHints().resources().resourcePatterns())
.singleElement().satisfies(resourceHint -> assertThat(resourceHint.getIncludes()).containsOnly(
"org/springframework/context/testfixture/context/generator/annotation/ImportConfiguration.class"));
}
@Test
void contributeWithNoImportAwareConfigurationReturnsNull() {
assertThat(createContribution(SimpleConfiguration.class)).isNull();
}
@Nullable
private BeanFactoryContribution createContribution(Class<?> type) {
DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
beanFactory.registerBeanDefinition("configuration", new RootBeanDefinition(type));
ConfigurationClassPostProcessor pp = new ConfigurationClassPostProcessor();
pp.postProcessBeanFactory(beanFactory);
return pp.contribute(beanFactory);
}
private GeneratedTypeContext createGenerationContext() {
return new DefaultGeneratedTypeContext("com.example", packageName ->
GeneratedType.of(ClassName.get(packageName, "Test")));
}
private String codeOf(GeneratedType type) {
try {
StringWriter out = new StringWriter();
type.toJavaFile().writeTo(out);
return out.toString();
}
catch (IOException ex) {
throw new IllegalStateException(ex);
}
}
}

View File

@ -0,0 +1,41 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.context.testfixture.context.generator.annotation;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.ImportAware;
import org.springframework.core.env.Environment;
import org.springframework.core.type.AnnotationMetadata;
@Configuration(proxyBeanMethods = false)
@SuppressWarnings("unused")
public class ImportAwareConfiguration implements ImportAware, EnvironmentAware {
private AnnotationMetadata annotationMetadata;
@Override
public void setImportMetadata(AnnotationMetadata importMetadata) {
this.annotationMetadata = importMetadata;
}
@Override
public void setEnvironment(Environment environment) {
}
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.context.testfixture.context.generator.annotation;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
@Configuration(proxyBeanMethods = false)
@Import(ImportAwareConfiguration.class)
public class ImportConfiguration {
}