Harmonize generated class name conventions

This commit moves the responsibility of naming classes to the
GenerationContext. This was already largely the case before, except that
the concept of a "mainTarget" and "featureNamePrefix" was specific
to bean factory initialization contributors.

ClassNameGenerator should now be instantiated with a default target
and an optional feature name prefix. As a result, it does no longer
generate class names in the "__" package.

GeneratedClasses can now provide a new, unique, GeneratedClass or
offer a container for retrieving the same GeneratedClass based on an
identifier. This lets all contributors use this facility rather than
creating JavaFile manually. This also means that ClassNameGenerator
is no longer exposed.

Because the naming conventions are now part of the GenerationContext, it
is required to be able to retrieve a specialized version of it if a
code generation round needs to use different naming conventions. A new
withName method has been added to that effect.

Closes gh-28585
This commit is contained in:
Stephane Nicoll 2022-06-22 14:20:00 +02:00
parent b121eed753
commit 6199835d6e
31 changed files with 652 additions and 565 deletions

View File

@ -42,6 +42,7 @@ import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryInitializationCode; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryInitializationCode;
import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolder; import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolder;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -71,7 +72,7 @@ class ScopedProxyBeanRegistrationAotProcessorTests {
this.beanFactory = new DefaultListableBeanFactory(); this.beanFactory = new DefaultListableBeanFactory();
this.processor = new TestBeanRegistrationsAotProcessor(); this.processor = new TestBeanRegistrationsAotProcessor();
this.generatedFiles = new InMemoryGeneratedFiles(); this.generatedFiles = new InMemoryGeneratedFiles();
this.generationContext = new DefaultGenerationContext(this.generatedFiles); this.generationContext = new TestGenerationContext(this.generatedFiles);
this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode();
} }

View File

@ -41,6 +41,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.AccessVisibility;
import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.hint.ExecutableHint; import org.springframework.aot.hint.ExecutableHint;
@ -79,11 +80,8 @@ import org.springframework.core.annotation.AnnotationAttributes;
import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.annotation.MergedAnnotation; import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.annotation.MergedAnnotations; import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.TypeSpec;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
@ -910,30 +908,28 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA
@Override @Override
public void applyTo(GenerationContext generationContext, public void applyTo(GenerationContext generationContext,
BeanRegistrationCode beanRegistrationCode) { BeanRegistrationCode beanRegistrationCode) {
GeneratedClass generatedClass = generationContext.getGeneratedClasses()
ClassName className = generationContext.getClassNameGenerator() .forFeatureComponent("Autowiring", this.target)
.generateClassName(this.target, "Autowiring"); .generate(type -> {
TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); type.addJavadoc("Autowiring for {@link $T}.", this.target);
classBuilder.addJavadoc("Autowiring for {@link $T}.", this.target); type.addModifiers(javax.lang.model.element.Modifier.PUBLIC);
classBuilder.addModifiers(javax.lang.model.element.Modifier.PUBLIC); });
classBuilder.addMethod(generateMethod(generationContext.getRuntimeHints())); generatedClass.getMethodGenerator().generateMethod(APPLY_METHOD)
JavaFile javaFile = JavaFile .using(generateMethod(generationContext.getRuntimeHints()));
.builder(className.packageName(), classBuilder.build()).build();
generationContext.getGeneratedFiles().addSourceFile(javaFile);
beanRegistrationCode.addInstancePostProcessor( beanRegistrationCode.addInstancePostProcessor(
MethodReference.ofStatic(className, APPLY_METHOD)); MethodReference.ofStatic(generatedClass.getName(), APPLY_METHOD));
} }
private MethodSpec generateMethod(RuntimeHints hints) { private Consumer<MethodSpec.Builder> generateMethod(RuntimeHints hints) {
MethodSpec.Builder builder = MethodSpec.methodBuilder(APPLY_METHOD); return method -> {
builder.addJavadoc("Apply the autowiring."); method.addJavadoc("Apply the autowiring.");
builder.addModifiers(javax.lang.model.element.Modifier.PUBLIC, method.addModifiers(javax.lang.model.element.Modifier.PUBLIC,
javax.lang.model.element.Modifier.STATIC); javax.lang.model.element.Modifier.STATIC);
builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER);
builder.addParameter(this.target, INSTANCE_PARAMETER); method.addParameter(this.target, INSTANCE_PARAMETER);
builder.returns(this.target); method.returns(this.target);
builder.addCode(generateMethodCode(hints)); method.addCode(generateMethodCode(hints));
return builder.build(); };
} }
private CodeBlock generateMethodCode(RuntimeHints hints) { private CodeBlock generateMethodCode(RuntimeHints hints) {

View File

@ -21,10 +21,8 @@ import java.util.List;
import javax.lang.model.element.Modifier; import javax.lang.model.element.Modifier;
import org.springframework.aot.generate.ClassGenerator.JavaFileGenerator;
import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodGenerator;
import org.springframework.aot.generate.MethodNameGenerator; import org.springframework.aot.generate.MethodNameGenerator;
@ -32,8 +30,6 @@ import org.springframework.aot.generate.MethodReference;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.TypeSpec;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
/** /**
@ -45,6 +41,8 @@ import org.springframework.lang.Nullable;
*/ */
class BeanDefinitionMethodGenerator { class BeanDefinitionMethodGenerator {
private static final String FEATURE_NAME = "BeanDefinitions";
private final BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; private final BeanDefinitionMethodGeneratorFactory methodGeneratorFactory;
private final RegisteredBean registeredBean; private final RegisteredBean registeredBean;
@ -81,22 +79,23 @@ class BeanDefinitionMethodGenerator {
* Generate the method that returns the {@link BeanDefinition} to be * Generate the method that returns the {@link BeanDefinition} to be
* registered. * registered.
* @param generationContext the generation context * @param generationContext the generation context
* @param featureNamePrefix the prefix to use for the feature name
* @param beanRegistrationsCode the bean registrations code * @param beanRegistrationsCode the bean registrations code
* @return a reference to the generated method. * @return a reference to the generated method.
*/ */
MethodReference generateBeanDefinitionMethod(GenerationContext generationContext, MethodReference generateBeanDefinitionMethod(GenerationContext generationContext,
String featureNamePrefix, BeanRegistrationsCode beanRegistrationsCode) { BeanRegistrationsCode beanRegistrationsCode) {
BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext, BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext,
beanRegistrationsCode, featureNamePrefix); beanRegistrationsCode);
Class<?> target = codeFragments.getTarget(this.registeredBean, Class<?> target = codeFragments.getTarget(this.registeredBean,
this.constructorOrFactoryMethod); this.constructorOrFactoryMethod);
if (!target.getName().startsWith("java.")) { if (!target.getName().startsWith("java.")) {
String featureName = featureNamePrefix + "BeanDefinitions"; GeneratedClass generatedClass = generationContext.getGeneratedClasses()
GeneratedClass generatedClass = generationContext.getClassGenerator() .forFeatureComponent(FEATURE_NAME, target)
.getOrGenerateClass(new BeanDefinitionsJavaFileGenerator(target), .getOrGenerate(FEATURE_NAME, type -> {
target, featureName); type.addJavadoc("Bean definitions for {@link $T}", target);
type.addModifiers(Modifier.PUBLIC);
});
MethodGenerator methodGenerator = generatedClass.getMethodGenerator() MethodGenerator methodGenerator = generatedClass.getMethodGenerator()
.withName(getName()); .withName(getName());
GeneratedMethod generatedMethod = generateBeanDefinitionMethod( GeneratedMethod generatedMethod = generateBeanDefinitionMethod(
@ -115,11 +114,10 @@ class BeanDefinitionMethodGenerator {
} }
private BeanRegistrationCodeFragments getCodeFragments(GenerationContext generationContext, private BeanRegistrationCodeFragments getCodeFragments(GenerationContext generationContext,
BeanRegistrationsCode beanRegistrationsCode, String featureNamePrefix) { BeanRegistrationsCode beanRegistrationsCode) {
BeanRegistrationCodeFragments codeFragments = new DefaultBeanRegistrationCodeFragments( BeanRegistrationCodeFragments codeFragments = new DefaultBeanRegistrationCodeFragments(
beanRegistrationsCode, this.registeredBean, this.methodGeneratorFactory, beanRegistrationsCode, this.registeredBean, this.methodGeneratorFactory);
featureNamePrefix);
for (BeanRegistrationAotContribution aotContribution : this.aotContributions) { for (BeanRegistrationAotContribution aotContribution : this.aotContributions) {
codeFragments = aotContribution.customizeBeanRegistrationCodeFragments(generationContext, codeFragments); codeFragments = aotContribution.customizeBeanRegistrationCodeFragments(generationContext, codeFragments);
} }
@ -172,41 +170,4 @@ class BeanDefinitionMethodGenerator {
return beanName; return beanName;
} }
/**
* {@link BeanDefinitionsJavaFileGenerator} to create the
* {@code BeanDefinitions} file.
*/
private static class BeanDefinitionsJavaFileGenerator implements JavaFileGenerator {
private final Class<?> target;
BeanDefinitionsJavaFileGenerator(Class<?> target) {
this.target = target;
}
@Override
public JavaFile generateJavaFile(ClassName className, GeneratedMethods methods) {
TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className);
classBuilder.addJavadoc("Bean definitions for {@link $T}", this.target);
classBuilder.addModifiers(Modifier.PUBLIC);
methods.doWithMethodSpecs(classBuilder::addMethod);
return JavaFile.builder(className.packageName(), classBuilder.build())
.build();
}
@Override
public int hashCode() {
return getClass().hashCode();
}
@Override
public boolean equals(Object obj) {
return getClass() == obj.getClass();
}
}
} }

View File

@ -18,7 +18,6 @@ package org.springframework.beans.factory.aot;
import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodGenerator;
import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference;
import org.springframework.lang.Nullable;
/** /**
* Interface that can be used to configure the code that will be generated to * Interface that can be used to configure the code that will be generated to
@ -35,24 +34,6 @@ public interface BeanFactoryInitializationCode {
*/ */
String BEAN_FACTORY_VARIABLE = "beanFactory"; String BEAN_FACTORY_VARIABLE = "beanFactory";
/**
* Return the target class for this bean factory or {@code null} if there is
* no target.
* @return the target
*/
@Nullable
default Class<?> getTarget() {
return null;
}
/**
* Return the name of the bean factory or and empty string if no ID is available.
* @return the bean factory name
*/
default String getName() {
return "";
}
/** /**
* Return a {@link MethodGenerator} that can be used to add more methods to * Return a {@link MethodGenerator} that can be used to add more methods to
* the Initializing code. * the Initializing code.

View File

@ -20,17 +20,15 @@ import java.util.Map;
import javax.lang.model.element.Modifier; import javax.lang.model.element.Modifier;
import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodGenerator;
import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference;
import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.TypeSpec;
/** /**
* AOT contribution from a {@link BeanRegistrationsAotProcessor} used to * AOT contribution from a {@link BeanRegistrationsAotProcessor} used to
@ -61,24 +59,23 @@ class BeanRegistrationsAotContribution
public void applyTo(GenerationContext generationContext, public void applyTo(GenerationContext generationContext,
BeanFactoryInitializationCode beanFactoryInitializationCode) { BeanFactoryInitializationCode beanFactoryInitializationCode) {
ClassName className = generationContext.getClassNameGenerator().generateClassName( GeneratedClass generatedClass = generationContext.getGeneratedClasses()
beanFactoryInitializationCode.getTarget(), .forFeature("BeanFactoryRegistrations").generate(type -> {
beanFactoryInitializationCode.getName() + "BeanFactoryRegistrations"); type.addJavadoc("Register bean definitions for the bean factory.");
type.addModifiers(Modifier.PUBLIC);
});
BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator( BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator(
className); generatedClass);
GeneratedMethod registerMethod = codeGenerator.getMethodGenerator() GeneratedMethod registerMethod = codeGenerator.getMethodGenerator()
.generateMethod("registerBeanDefinitions") .generateMethod("registerBeanDefinitions")
.using(builder -> generateRegisterMethod(builder, generationContext, .using(builder -> generateRegisterMethod(builder, generationContext,
beanFactoryInitializationCode.getName(),
codeGenerator)); codeGenerator));
JavaFile javaFile = codeGenerator.generatedJavaFile(className);
generationContext.getGeneratedFiles().addSourceFile(javaFile);
beanFactoryInitializationCode beanFactoryInitializationCode
.addInitializer(MethodReference.of(className, registerMethod.getName())); .addInitializer(MethodReference.of(generatedClass.getName(), registerMethod.getName()));
} }
private void generateRegisterMethod(MethodSpec.Builder builder, private void generateRegisterMethod(MethodSpec.Builder builder,
GenerationContext generationContext, String featureNamePrefix, GenerationContext generationContext,
BeanRegistrationsCode beanRegistrationsCode) { BeanRegistrationsCode beanRegistrationsCode) {
builder.addJavadoc("Register the bean definitions."); builder.addJavadoc("Register the bean definitions.");
@ -88,7 +85,7 @@ class BeanRegistrationsAotContribution
CodeBlock.Builder code = CodeBlock.builder(); CodeBlock.Builder code = CodeBlock.builder();
this.registrations.forEach((beanName, beanDefinitionMethodGenerator) -> { this.registrations.forEach((beanName, beanDefinitionMethodGenerator) -> {
MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator
.generateBeanDefinitionMethod(generationContext, featureNamePrefix, .generateBeanDefinitionMethod(generationContext,
beanRegistrationsCode); beanRegistrationsCode);
code.addStatement("$L.registerBeanDefinition($S, $L)", code.addStatement("$L.registerBeanDefinition($S, $L)",
BEAN_FACTORY_PARAMETER_NAME, beanName, BEAN_FACTORY_PARAMETER_NAME, beanName,
@ -103,33 +100,21 @@ class BeanRegistrationsAotContribution
*/ */
static class BeanRegistrationsCodeGenerator implements BeanRegistrationsCode { static class BeanRegistrationsCodeGenerator implements BeanRegistrationsCode {
private final ClassName className; private final GeneratedClass generatedClass;
private final GeneratedMethods generatedMethods = new GeneratedMethods(); public BeanRegistrationsCodeGenerator(GeneratedClass generatedClass) {
this.generatedClass = generatedClass;
public BeanRegistrationsCodeGenerator(ClassName className) {
this.className = className;
} }
@Override @Override
public ClassName getClassName() { public ClassName getClassName() {
return this.className; return this.generatedClass.getName();
} }
@Override @Override
public MethodGenerator getMethodGenerator() { public MethodGenerator getMethodGenerator() {
return this.generatedMethods; return this.generatedClass.getMethodGenerator();
}
JavaFile generatedJavaFile(ClassName className) {
TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className);
classBuilder.addJavadoc("Register bean definitions for the bean factory.");
classBuilder.addModifiers(Modifier.PUBLIC);
this.generatedMethods.doWithMethodSpecs(classBuilder::addMethod);
return JavaFile.builder(className.packageName(), classBuilder.build())
.build();
} }
} }

View File

@ -54,18 +54,14 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments
private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory; private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory;
private final String featureNamePrefix;
DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode, DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode,
RegisteredBean registeredBean, RegisteredBean registeredBean,
BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory, BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory) {
String featureNamePrefix) {
this.beanRegistrationsCode = beanRegistrationsCode; this.beanRegistrationsCode = beanRegistrationsCode;
this.registeredBean = registeredBean; this.registeredBean = registeredBean;
this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory; this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory;
this.featureNamePrefix = featureNamePrefix;
} }
@ -124,7 +120,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments
.getBeanDefinitionMethodGenerator(innerRegisteredBean, name); .getBeanDefinitionMethodGenerator(innerRegisteredBean, name);
Assert.state(methodGenerator != null, "Unexpected filtering of inner-bean"); Assert.state(methodGenerator != null, "Unexpected filtering of inner-bean");
MethodReference generatedMethod = methodGenerator MethodReference generatedMethod = methodGenerator
.generateBeanDefinitionMethod(generationContext, this.featureNamePrefix, .generateBeanDefinitionMethod(generationContext,
this.beanRegistrationsCode); this.beanRegistrationsCode);
return generatedMethod.toInvokeCodeBlock(); return generatedMethod.toInvokeCodeBlock();
} }

View File

@ -25,7 +25,6 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.DefaultGenerationContext; import org.springframework.aot.generate.DefaultGenerationContext;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.generate.InMemoryGeneratedFiles;
import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHints;
@ -40,6 +39,7 @@ import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationCode; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationCode;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
import org.springframework.core.env.StandardEnvironment; import org.springframework.core.env.StandardEnvironment;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.MethodSpec;
@ -59,7 +59,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests {
private InMemoryGeneratedFiles generatedFiles; private InMemoryGeneratedFiles generatedFiles;
private GenerationContext generationContext; private DefaultGenerationContext generationContext;
private RuntimeHints runtimeHints; private RuntimeHints runtimeHints;
@ -70,7 +70,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests {
@BeforeEach @BeforeEach
void setup() { void setup() {
this.generatedFiles = new InMemoryGeneratedFiles(); this.generatedFiles = new InMemoryGeneratedFiles();
this.generationContext = new DefaultGenerationContext(this.generatedFiles); this.generationContext = new TestGenerationContext(this.generatedFiles);
this.runtimeHints = this.generationContext.getRuntimeHints(); this.runtimeHints = this.generationContext.getRuntimeHints();
this.beanRegistrationCode = new MockBeanRegistrationCode(); this.beanRegistrationCode = new MockBeanRegistrationCode();
this.beanFactory = new DefaultListableBeanFactory(); this.beanFactory = new DefaultListableBeanFactory();
@ -169,6 +169,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private void testCompiledResult(RegisteredBean registeredBean, private void testCompiledResult(RegisteredBean registeredBean,
BiConsumer<BiFunction<RegisteredBean, Object, Object>, Compiled> result) { BiConsumer<BiFunction<RegisteredBean, Object, Object>, Compiled> result) {
this.generationContext.writeGeneratedContent();
JavaFile javaFile = createJavaFile(registeredBean.getBeanClass()); JavaFile javaFile = createJavaFile(registeredBean.getBeanClass());
TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo, TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo,
compiled -> result.accept(compiled.getInstance(BiFunction.class), compiled -> result.accept(compiled.getInstance(BiFunction.class),

View File

@ -50,6 +50,7 @@ import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.core.mock.MockSpringFactoriesLoader; import org.springframework.core.mock.MockSpringFactoriesLoader;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.JavaFile;
@ -80,7 +81,7 @@ class BeanDefinitionMethodGeneratorTests {
@BeforeEach @BeforeEach
void setup() { void setup() {
this.generatedFiles = new InMemoryGeneratedFiles(); this.generatedFiles = new InMemoryGeneratedFiles();
this.generationContext = new DefaultGenerationContext(this.generatedFiles); this.generationContext = new TestGenerationContext(this.generatedFiles);
this.beanFactory = new DefaultListableBeanFactory(); this.beanFactory = new DefaultListableBeanFactory();
this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory(
new AotFactoriesLoader(this.beanFactory, new MockSpringFactoriesLoader())); new AotFactoriesLoader(this.beanFactory, new MockSpringFactoriesLoader()));
@ -96,7 +97,7 @@ class BeanDefinitionMethodGeneratorTests {
this.methodGeneratorFactory, registeredBean, null, this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList()); Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod( MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, "", this.beanRegistrationsCode); this.generationContext, this.beanRegistrationsCode);
testCompiledResult(method, (actual, compiled) -> { testCompiledResult(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
@ -114,7 +115,7 @@ class BeanDefinitionMethodGeneratorTests {
this.methodGeneratorFactory, registeredBean, null, this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList()); Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod( MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, "", this.beanRegistrationsCode); this.generationContext, this.beanRegistrationsCode);
testCompiledResult(method, (actual, compiled) -> { testCompiledResult(method, (actual, compiled) -> {
assertThat(actual.getResolvableType().resolve()).isEqualTo(GenericBean.class); assertThat(actual.getResolvableType().resolve()).isEqualTo(GenericBean.class);
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
@ -147,7 +148,7 @@ class BeanDefinitionMethodGeneratorTests {
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null, aotContributions); this.methodGeneratorFactory, registeredBean, null, aotContributions);
MethodReference method = generator.generateBeanDefinitionMethod( MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, "", this.beanRegistrationsCode); this.generationContext, this.beanRegistrationsCode);
testCompiledResult(method, (actual, compiled) -> { testCompiledResult(method, (actual, compiled) -> {
assertThat(actual.getBeanClass()).isEqualTo(TestBean.class); assertThat(actual.getBeanClass()).isEqualTo(TestBean.class);
InstanceSupplier<?> supplier = (InstanceSupplier<?>) actual InstanceSupplier<?> supplier = (InstanceSupplier<?>) actual
@ -173,7 +174,7 @@ class BeanDefinitionMethodGeneratorTests {
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null, aotContributions); this.methodGeneratorFactory, registeredBean, null, aotContributions);
MethodReference method = generator.generateBeanDefinitionMethod( MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, "", this.beanRegistrationsCode); this.generationContext, this.beanRegistrationsCode);
testCompiledResult(method, (actual, compiled) -> { testCompiledResult(method, (actual, compiled) -> {
assertThat(actual.getBeanClass()).isEqualTo(TestBean.class); assertThat(actual.getBeanClass()).isEqualTo(TestBean.class);
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
@ -213,7 +214,7 @@ class BeanDefinitionMethodGeneratorTests {
this.methodGeneratorFactory, registeredBean, null, this.methodGeneratorFactory, registeredBean, null,
aotContributions); aotContributions);
MethodReference method = generator.generateBeanDefinitionMethod( MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, "", this.beanRegistrationsCode); this.generationContext, this.beanRegistrationsCode);
testCompiledResult(method, (actual, compiled) -> { testCompiledResult(method, (actual, compiled) -> {
assertThat(actual.getAttribute("a")).isEqualTo("A"); assertThat(actual.getAttribute("a")).isEqualTo("A");
assertThat(actual.getAttribute("b")).isNull(); assertThat(actual.getAttribute("b")).isNull();
@ -246,7 +247,7 @@ class BeanDefinitionMethodGeneratorTests {
this.methodGeneratorFactory, innerBean, "testInnerBean", this.methodGeneratorFactory, innerBean, "testInnerBean",
Collections.emptyList()); Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod( MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, "", this.beanRegistrationsCode); this.generationContext, this.beanRegistrationsCode);
testCompiledResult(method, (actual, compiled) -> { testCompiledResult(method, (actual, compiled) -> {
assertThat(compiled.getSourceFile(".*BeanDefinitions")) assertThat(compiled.getSourceFile(".*BeanDefinitions"))
.contains("Get the inner-bean definition for 'testInnerBean'"); .contains("Get the inner-bean definition for 'testInnerBean'");
@ -267,7 +268,7 @@ class BeanDefinitionMethodGeneratorTests {
this.methodGeneratorFactory, registeredBean, null, this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList()); Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod( MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, "", this.beanRegistrationsCode); this.generationContext, this.beanRegistrationsCode);
testCompiledResult(method, (actual, compiled) -> { testCompiledResult(method, (actual, compiled) -> {
RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual
.getPropertyValues().get("name"); .getPropertyValues().get("name");
@ -301,7 +302,7 @@ class BeanDefinitionMethodGeneratorTests {
this.methodGeneratorFactory, registeredBean, null, this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList()); Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod( MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, "", this.beanRegistrationsCode); this.generationContext, this.beanRegistrationsCode);
testCompiledResult(method, (actual, compiled) -> { testCompiledResult(method, (actual, compiled) -> {
RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual
.getConstructorArgumentValues() .getConstructorArgumentValues()
@ -334,7 +335,7 @@ class BeanDefinitionMethodGeneratorTests {
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null, aotContributions); this.methodGeneratorFactory, registeredBean, null, aotContributions);
MethodReference method = generator.generateBeanDefinitionMethod( MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, "", this.beanRegistrationsCode); this.generationContext, this.beanRegistrationsCode);
testCompiledResult(method, (actual, compiled) -> { testCompiledResult(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("AotContributedMethod()"); assertThat(sourceFile).contains("AotContributedMethod()");
@ -351,7 +352,7 @@ class BeanDefinitionMethodGeneratorTests {
this.methodGeneratorFactory, registeredBean, null, this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList()); Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod( MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, "", this.beanRegistrationsCode); this.generationContext, this.beanRegistrationsCode);
testCompiledResult(method, (actual, compiled) -> { testCompiledResult(method, (actual, compiled) -> {
DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory();
freshBeanFactory.registerBeanDefinition("test", actual); freshBeanFactory.registerBeanDefinition("test", actual);

View File

@ -29,6 +29,7 @@ import javax.lang.model.element.Modifier;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.ClassNameGenerator;
import org.springframework.aot.generate.DefaultGenerationContext; import org.springframework.aot.generate.DefaultGenerationContext;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.generate.InMemoryGeneratedFiles;
@ -42,6 +43,8 @@ import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryInitializationCode; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryInitializationCode;
import org.springframework.core.mock.MockSpringFactoriesLoader; import org.springframework.core.mock.MockSpringFactoriesLoader;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.core.testfixture.aot.generate.TestTarget;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.MethodSpec;
@ -72,7 +75,7 @@ class BeanRegistrationsAotContributionTests {
@BeforeEach @BeforeEach
void setup() { void setup() {
this.generatedFiles = new InMemoryGeneratedFiles(); this.generatedFiles = new InMemoryGeneratedFiles();
this.generationContext = new DefaultGenerationContext(this.generatedFiles); this.generationContext = new TestGenerationContext(this.generatedFiles);
this.beanFactory = new DefaultListableBeanFactory(); this.beanFactory = new DefaultListableBeanFactory();
this.springFactoriesLoader = new MockSpringFactoriesLoader(); this.springFactoriesLoader = new MockSpringFactoriesLoader();
this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory(
@ -100,7 +103,9 @@ class BeanRegistrationsAotContributionTests {
@Test @Test
void applyToWhenHasNameGeneratesPrefixedFeatureName() { void applyToWhenHasNameGeneratesPrefixedFeatureName() {
this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode("Management"); this.generationContext = new DefaultGenerationContext(
new ClassNameGenerator(TestTarget.class, "Management"), this.generatedFiles);
this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode();
Map<String, BeanDefinitionMethodGenerator> registrations = new LinkedHashMap<>(); Map<String, BeanDefinitionMethodGenerator> registrations = new LinkedHashMap<>();
RegisteredBean registeredBean = registerBean( RegisteredBean registeredBean = registerBean(
new RootBeanDefinition(TestBean.class)); new RootBeanDefinition(TestBean.class));
@ -129,11 +134,11 @@ class BeanRegistrationsAotContributionTests {
@Override @Override
MethodReference generateBeanDefinitionMethod( MethodReference generateBeanDefinitionMethod(
GenerationContext generationContext, String featureNamePrefix, GenerationContext generationContext,
BeanRegistrationsCode beanRegistrationsCode) { BeanRegistrationsCode beanRegistrationsCode) {
beanRegistrationsCodes.add(beanRegistrationsCode); beanRegistrationsCodes.add(beanRegistrationsCode);
return super.generateBeanDefinitionMethod(generationContext, return super.generateBeanDefinitionMethod(generationContext,
featureNamePrefix, beanRegistrationsCode); beanRegistrationsCode);
} }
}; };

View File

@ -52,6 +52,7 @@ import org.springframework.beans.testfixture.beans.factory.generator.factory.Num
import org.springframework.beans.testfixture.beans.factory.generator.factory.SampleFactory; import org.springframework.beans.testfixture.beans.factory.generator.factory.SampleFactory;
import org.springframework.beans.testfixture.beans.factory.generator.injection.InjectionComponent; import org.springframework.beans.testfixture.beans.factory.generator.injection.InjectionComponent;
import org.springframework.core.env.StandardEnvironment; import org.springframework.core.env.StandardEnvironment;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.JavaFile;
@ -82,7 +83,7 @@ class InstanceSupplierCodeGeneratorTests {
@BeforeEach @BeforeEach
void setup() { void setup() {
this.generatedFiles = new InMemoryGeneratedFiles(); this.generatedFiles = new InMemoryGeneratedFiles();
this.generationContext = new DefaultGenerationContext(this.generatedFiles); this.generationContext = new TestGenerationContext(this.generatedFiles);
} }

View File

@ -35,21 +35,6 @@ public class MockBeanFactoryInitializationCode implements BeanFactoryInitializat
private final List<MethodReference> initializers = new ArrayList<>(); private final List<MethodReference> initializers = new ArrayList<>();
private final String name;
public MockBeanFactoryInitializationCode() {
this("");
}
public MockBeanFactoryInitializationCode(String name) {
this.name = name;
}
@Override
public String getName() {
return this.name;
}
@Override @Override
public GeneratedMethods getMethodGenerator() { public GeneratedMethods getMethodGenerator() {
return this.generatedMethods; return this.generatedMethods;

View File

@ -16,14 +16,14 @@
package org.springframework.context.aot; package org.springframework.context.aot;
import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.support.GenericApplicationContext; import org.springframework.context.support.GenericApplicationContext;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.JavaFile;
import org.springframework.lang.Nullable;
/** /**
* Process an {@link ApplicationContext} and its {@link BeanFactory} to generate * Process an {@link ApplicationContext} and its {@link BeanFactory} to generate
@ -42,41 +42,20 @@ public class ApplicationContextAotGenerator {
* specified {@link GenerationContext}. * specified {@link GenerationContext}.
* @param applicationContext the application context to handle * @param applicationContext the application context to handle
* @param generationContext the generation context to use * @param generationContext the generation context to use
* @param generatedInitializerClassName the class name to use for the * @return the class name of the {@link ApplicationContextInitializer} entry point
* generated application context initializer
*/ */
public void generateApplicationContext(GenericApplicationContext applicationContext, public ClassName generateApplicationContext(GenericApplicationContext applicationContext,
GenerationContext generationContext, GenerationContext generationContext) {
ClassName generatedInitializerClassName) {
generateApplicationContext(applicationContext, null, null, generationContext,
generatedInitializerClassName);
}
/**
* Refresh the specified {@link GenericApplicationContext} and generate the
* necessary code to restore the state of its {@link BeanFactory}, using the
* specified {@link GenerationContext}.
* @param applicationContext the application context to handle
* @param target the target class for the generated initializer (used when generating class names)
* @param name the name of the application context (used when generating class names)
* @param generationContext the generation context to use
* @param generatedInitializerClassName the class name to use for the
* generated application context initializer
*/
public void generateApplicationContext(GenericApplicationContext applicationContext,
@Nullable Class<?> target, @Nullable String name, GenerationContext generationContext,
ClassName generatedInitializerClassName) {
applicationContext.refreshForAotProcessing(); applicationContext.refreshForAotProcessing();
DefaultListableBeanFactory beanFactory = applicationContext DefaultListableBeanFactory beanFactory = applicationContext
.getDefaultListableBeanFactory(); .getDefaultListableBeanFactory();
ApplicationContextInitializationCodeGenerator codeGenerator = new ApplicationContextInitializationCodeGenerator( ApplicationContextInitializationCodeGenerator codeGenerator = new ApplicationContextInitializationCodeGenerator();
target, name);
new BeanFactoryInitializationAotContributions(beanFactory).applyTo(generationContext, new BeanFactoryInitializationAotContributions(beanFactory).applyTo(generationContext,
codeGenerator); codeGenerator);
JavaFile javaFile = codeGenerator.generateJavaFile(generatedInitializerClassName); GeneratedClass applicationContextInitializer = generationContext.getGeneratedClasses()
generationContext.getGeneratedFiles().addSourceFile(javaFile); .forFeature("ApplicationContextInitializer")
.generate(codeGenerator.generateJavaFile());
return applicationContextInitializer.getName();
} }
} }

View File

@ -18,6 +18,7 @@ package org.springframework.context.aot;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.function.Consumer;
import javax.lang.model.element.Modifier; import javax.lang.model.element.Modifier;
@ -29,14 +30,10 @@ import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver; import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver;
import org.springframework.context.support.GenericApplicationContext; import org.springframework.context.support.GenericApplicationContext;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.ParameterizedTypeName; import org.springframework.javapoet.ParameterizedTypeName;
import org.springframework.javapoet.TypeSpec; import org.springframework.javapoet.TypeSpec;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
/** /**
* Internal code generator to create the application context initializer. * Internal code generator to create the application context initializer.
@ -50,33 +47,11 @@ class ApplicationContextInitializationCodeGenerator
private static final String APPLICATION_CONTEXT_VARIABLE = "applicationContext"; private static final String APPLICATION_CONTEXT_VARIABLE = "applicationContext";
@Nullable
private final Class<?> target;
private final String name;
private final GeneratedMethods generatedMethods = new GeneratedMethods(); private final GeneratedMethods generatedMethods = new GeneratedMethods();
private final List<MethodReference> initializers = new ArrayList<>(); private final List<MethodReference> initializers = new ArrayList<>();
ApplicationContextInitializationCodeGenerator(@Nullable Class<?> target, @Nullable String name) {
this.target = target;
this.name = (!StringUtils.hasText(name)) ? "" : name;
}
@Override
@Nullable
public Class<?> getTarget() {
return this.target;
}
@Override
public String getName() {
return this.name;
}
@Override @Override
public MethodGenerator getMethodGenerator() { public MethodGenerator getMethodGenerator() {
return this.generatedMethods; return this.generatedMethods;
@ -87,8 +62,8 @@ class ApplicationContextInitializationCodeGenerator
this.initializers.add(methodReference); this.initializers.add(methodReference);
} }
JavaFile generateJavaFile(ClassName className) { Consumer<TypeSpec.Builder> generateJavaFile() {
TypeSpec.Builder builder = TypeSpec.classBuilder(className); return builder -> {
builder.addJavadoc( builder.addJavadoc(
"{@link $T} to restore an application context based on previous AOT processing.", "{@link $T} to restore an application context based on previous AOT processing.",
ApplicationContextInitializer.class); ApplicationContextInitializer.class);
@ -97,7 +72,7 @@ class ApplicationContextInitializationCodeGenerator
ApplicationContextInitializer.class, GenericApplicationContext.class)); ApplicationContextInitializer.class, GenericApplicationContext.class));
builder.addMethod(generateInitializeMethod()); builder.addMethod(generateInitializeMethod());
this.generatedMethods.doWithMethodSpecs(builder::addMethod); this.generatedMethods.doWithMethodSpecs(builder::addMethod);
return JavaFile.builder(className.packageName(), builder.build()).build(); };
} }
private MethodSpec generateInitializeMethod() { private MethodSpec generateInitializeMethod() {

View File

@ -37,6 +37,7 @@ import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryIn
import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration; import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration;
import org.springframework.context.testfixture.context.generator.annotation.ImportAwareConfiguration; import org.springframework.context.testfixture.context.generator.annotation.ImportAwareConfiguration;
import org.springframework.context.testfixture.context.generator.annotation.ImportConfiguration; import org.springframework.context.testfixture.context.generator.annotation.ImportConfiguration;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.MethodSpec;
@ -59,7 +60,7 @@ class ConfigurationClassPostProcessorAotContributionTests {
private InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); private InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles();
private DefaultGenerationContext generationContext = new DefaultGenerationContext( private DefaultGenerationContext generationContext = new TestGenerationContext(
this.generatedFiles); this.generatedFiles);
private MockBeanFactoryInitializationCode beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); private MockBeanFactoryInitializationCode beanFactoryInitializationCode = new MockBeanFactoryInitializationCode();

View File

@ -44,7 +44,7 @@ import org.springframework.context.support.GenericApplicationContext;
import org.springframework.context.testfixture.context.generator.SimpleComponent; import org.springframework.context.testfixture.context.generator.SimpleComponent;
import org.springframework.context.testfixture.context.generator.annotation.AutowiredComponent; import org.springframework.context.testfixture.context.generator.annotation.AutowiredComponent;
import org.springframework.context.testfixture.context.generator.annotation.InitDestroyComponent; import org.springframework.context.testfixture.context.generator.annotation.InitDestroyComponent;
import org.springframework.javapoet.ClassName; import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -56,9 +56,6 @@ import static org.assertj.core.api.Assertions.assertThat;
*/ */
class ApplicationContextAotGeneratorTests { class ApplicationContextAotGeneratorTests {
private static final ClassName MAIN_GENERATED_TYPE = ClassName.get("__",
"TestInitializer");
@Test @Test
void generateApplicationContextWhenHasSimpleBean() { void generateApplicationContextWhenHasSimpleBean() {
GenericApplicationContext applicationContext = new GenericApplicationContext(); GenericApplicationContext applicationContext = new GenericApplicationContext();
@ -191,10 +188,9 @@ class ApplicationContextAotGeneratorTests {
BiConsumer<ApplicationContextInitializer<GenericApplicationContext>, Compiled> result) { BiConsumer<ApplicationContextInitializer<GenericApplicationContext>, Compiled> result) {
ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator();
InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles();
DefaultGenerationContext generationContext = new DefaultGenerationContext( DefaultGenerationContext generationContext = new TestGenerationContext(
generatedFiles); generatedFiles);
generator.generateApplicationContext(applicationContext, generationContext, generator.generateApplicationContext(applicationContext, generationContext);
MAIN_GENERATED_TYPE);
generationContext.writeGeneratedContent(); generationContext.writeGeneratedContent();
TestCompiler.forSystem().withFiles(generatedFiles) TestCompiler.forSystem().withFiles(generatedFiles)
.compile(compiled -> result.accept( .compile(compiled -> result.accept(

View File

@ -24,9 +24,7 @@ import java.lang.annotation.Target;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.DefaultGenerationContext;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.InMemoryGeneratedFiles;
import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsPredicates; import org.springframework.aot.hint.RuntimeHintsPredicates;
@ -39,6 +37,7 @@ import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.annotation.AliasFor; import org.springframework.core.annotation.AliasFor;
import org.springframework.core.annotation.SynthesizedAnnotation; import org.springframework.core.annotation.SynthesizedAnnotation;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -54,8 +53,7 @@ class ReflectiveProcessorBeanRegistrationAotProcessorTests {
private final ReflectiveProcessorBeanRegistrationAotProcessor processor = new ReflectiveProcessorBeanRegistrationAotProcessor(); private final ReflectiveProcessorBeanRegistrationAotProcessor processor = new ReflectiveProcessorBeanRegistrationAotProcessor();
private final GenerationContext generationContext = new DefaultGenerationContext( private final GenerationContext generationContext = new TestGenerationContext();
new InMemoryGeneratedFiles());
@Test @Test
void shouldIgnoreNonAnnotatedType() { void shouldIgnoreNonAnnotatedType() {

View File

@ -25,9 +25,7 @@ import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.DefaultGenerationContext;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.InMemoryGeneratedFiles;
import org.springframework.aot.hint.ResourceBundleHint; import org.springframework.aot.hint.ResourceBundleHint;
import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar; import org.springframework.aot.hint.RuntimeHintsRegistrar;
@ -38,7 +36,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.ImportRuntimeHints; import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.context.support.GenericApplicationContext; import org.springframework.context.support.GenericApplicationContext;
import org.springframework.javapoet.ClassName; import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -51,17 +49,13 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
*/ */
class RuntimeHintsBeanFactoryInitializationAotProcessorTests { class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
private static final ClassName MAIN_GENERATED_TYPE = ClassName.get("__",
"TestInitializer");
private GenerationContext generationContext; private GenerationContext generationContext;
private ApplicationContextAotGenerator generator; private ApplicationContextAotGenerator generator;
@BeforeEach @BeforeEach
void setup() { void setup() {
this.generationContext = new DefaultGenerationContext( this.generationContext = new TestGenerationContext();
new InMemoryGeneratedFiles());
this.generator = new ApplicationContextAotGenerator(); this.generator = new ApplicationContextAotGenerator();
} }
@ -70,7 +64,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
GenericApplicationContext applicationContext = createApplicationContext( GenericApplicationContext applicationContext = createApplicationContext(
ConfigurationWithHints.class); ConfigurationWithHints.class);
this.generator.generateApplicationContext(applicationContext, this.generator.generateApplicationContext(applicationContext,
this.generationContext, MAIN_GENERATED_TYPE); this.generationContext);
assertThatSampleRegistrarContributed(); assertThatSampleRegistrarContributed();
} }
@ -79,7 +73,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
GenericApplicationContext applicationContext = createApplicationContext( GenericApplicationContext applicationContext = createApplicationContext(
ConfigurationWithBeanDeclaringHints.class); ConfigurationWithBeanDeclaringHints.class);
this.generator.generateApplicationContext(applicationContext, this.generator.generateApplicationContext(applicationContext,
this.generationContext, MAIN_GENERATED_TYPE); this.generationContext);
assertThatSampleRegistrarContributed(); assertThatSampleRegistrarContributed();
} }
@ -89,7 +83,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
applicationContext.setClassLoader( applicationContext.setClassLoader(
new TestSpringFactoriesClassLoader("test-runtime-hints-aot.factories")); new TestSpringFactoriesClassLoader("test-runtime-hints-aot.factories"));
this.generator.generateApplicationContext(applicationContext, this.generator.generateApplicationContext(applicationContext,
this.generationContext, MAIN_GENERATED_TYPE); this.generationContext);
assertThatSampleRegistrarContributed(); assertThatSampleRegistrarContributed();
} }
@ -104,7 +98,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
new TestSpringFactoriesClassLoader("test-duplicated-runtime-hints-aot.factories")); new TestSpringFactoriesClassLoader("test-duplicated-runtime-hints-aot.factories"));
IncrementalRuntimeHintsRegistrar.counter.set(0); IncrementalRuntimeHintsRegistrar.counter.set(0);
this.generator.generateApplicationContext(applicationContext, this.generator.generateApplicationContext(applicationContext,
this.generationContext, MAIN_GENERATED_TYPE); this.generationContext);
RuntimeHints runtimeHints = this.generationContext.getRuntimeHints(); RuntimeHints runtimeHints = this.generationContext.getRuntimeHints();
assertThat(runtimeHints.resources().resourceBundles().map(ResourceBundleHint::getBaseName)) assertThat(runtimeHints.resources().resourceBundles().map(ResourceBundleHint::getBaseName))
.containsOnly("com.example.example0", "sample"); .containsOnly("com.example.example0", "sample");
@ -116,7 +110,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests {
GenericApplicationContext applicationContext = createApplicationContext( GenericApplicationContext applicationContext = createApplicationContext(
ConfigurationWithIllegalRegistrar.class); ConfigurationWithIllegalRegistrar.class);
assertThatThrownBy(() -> this.generator.generateApplicationContext( assertThatThrownBy(() -> this.generator.generateApplicationContext(
applicationContext, this.generationContext, MAIN_GENERATED_TYPE)) applicationContext, this.generationContext))
.isInstanceOf(BeanInstantiationException.class); .isInstanceOf(BeanInstantiationException.class);
} }

View File

@ -1,75 +0,0 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.aot.generate;
import java.util.Collection;
import java.util.Collections;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.JavaFile;
/**
* Generates new {@link GeneratedClass} instances.
*
* @author Phillip Webb
* @since 6.0
* @see GeneratedMethods
*/
public interface ClassGenerator {
/**
* Get or generate a new {@link GeneratedClass} for a given java file
* generator, target and feature name.
* @param javaFileGenerator the java file generator
* @param target the target of the newly generated class
* @param featureName the name of the feature that the generated class
* supports
* @return a {@link GeneratedClass} instance
*/
GeneratedClass getOrGenerateClass(JavaFileGenerator javaFileGenerator,
Class<?> target, String featureName);
/**
* Strategy used to generate the java file for the generated class.
* Implementations of this interface are included as part of the key used to
* identify classes that have already been created and as such should be
* static final instances or implement a valid
* {@code equals}/{@code hashCode}.
*/
@FunctionalInterface
interface JavaFileGenerator {
/**
* Generate the file {@link JavaFile} to be written.
* @param className the class name of the file
* @param methods the generated methods that must be included
* @return the generated files
*/
JavaFile generateJavaFile(ClassName className, GeneratedMethods methods);
/**
* Return method names that must not be generated.
* @return the reserved method names
*/
default Collection<String> getReservedMethodNames() {
return Collections.emptySet();
}
}
}

View File

@ -27,10 +27,9 @@ import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
/** /**
* Generate unique class names based on an optional target {@link Class} and * Generate unique class names based on target {@link Class} and a feature
* a feature name. This class is stateful so the same instance should be used * name. This class is stateful so the same instance should be used for all
* for all name generation. Most commonly the class name generator is obtained * name generation.
* via a {@link GenerationContext}.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Stephane Nicoll * @author Stephane Nicoll
@ -40,38 +39,92 @@ public final class ClassNameGenerator {
private static final String SEPARATOR = "__"; private static final String SEPARATOR = "__";
private static final String AOT_PACKAGE = "__.";
private static final String AOT_FEATURE = "Aot"; private static final String AOT_FEATURE = "Aot";
private final Map<String, AtomicInteger> sequenceGenerator = new ConcurrentHashMap<>(); private final Class<?> defaultTarget;
private final String featureNamePrefix;
private final Map<String, AtomicInteger> sequenceGenerator;
/**
* Create a new instance using the specified {@code defaultTarget} and no
* feature name prefix.
* @param defaultTarget the default target class to use
*/
public ClassNameGenerator(Class<?> defaultTarget) {
this(defaultTarget, "");
}
/**
* Create a new instance using the specified {@code defaultTarget} and
* feature name prefix.
* @param defaultTarget the default target class to use
* @param featureNamePrefix the prefix to use to qualify feature names
*/
public ClassNameGenerator(Class<?> defaultTarget, String featureNamePrefix) {
this(defaultTarget, featureNamePrefix, new ConcurrentHashMap<>());
}
private ClassNameGenerator(Class<?> defaultTarget, String featureNamePrefix,
Map<String, AtomicInteger> sequenceGenerator) {
this.defaultTarget = defaultTarget;
this.featureNamePrefix = (!StringUtils.hasText(featureNamePrefix) ? "" : featureNamePrefix);
this.sequenceGenerator = sequenceGenerator;
}
/** /**
* Generate a unique {@link ClassName} based on the specified {@code target} * Generate a unique {@link ClassName} based on the specified
* class and {@code featureName}. If a {@code target} is specified, the * {@code featureName} and {@code target}. If the {@code target} is
* generated class name is a suffixed version of it. * {@code null}, the configured main target of this instance is used.
* <p>For instance, a {@code com.example.Demo} target with an * <p>The class name is a suffixed version of the target. For instance, a
* {@code Initializer} feature name leads to a * {@code com.example.Demo} target with an {@code Initializer} feature name
* {@code com.example.Demo__Initializer} generated class name. If such a * leads to a {@code com.example.Demo__Initializer} generated class name.
* feature was already requested for this target, a counter is used to * The feature name is qualified by the configured feature name prefix,
* ensure uniqueness. * if any.
* <p>If there is no target, the {@code featureName} is used to generate the * <p>Generated class names are unique. If such a feature was already
* class name in the {@value #AOT_PACKAGE} package. * requested for this target, a counter is used to ensure uniqueness.
* @param target the class the newly generated class relates to, or * @param target the class the newly generated class relates to, or
* {@code null} if there is not target * {@code null} to use the main target
* @param featureName the name of the feature that the generated class * @param featureName the name of the feature that the generated class
* supports * supports
* @return a unique generated class name * @return a unique generated class name
*/ */
public ClassName generateClassName(@Nullable Class<?> target, String featureName) { public ClassName generateClassName(@Nullable Class<?> target, String featureName) {
return generateSequencedClassName(getClassName(target, featureName));
}
/**
* Return a class name based on the specified {@code target} and
* {@code featureName}. This uses the same algorithm as
* {@link #generateClassName(Class, String)} but does not register
* the class name, nor add a unique suffix to it if necessary.
* @param target the class the newly generated class relates to, or
* {@code null} to use the main target
* @param featureName the name of the feature that the generated class
* supports
* @return the class name
*/
String getClassName(@Nullable Class<?> target, String featureName) {
Assert.hasLength(featureName, "'featureName' must not be empty"); Assert.hasLength(featureName, "'featureName' must not be empty");
featureName = clean(featureName); featureName = clean(featureName);
if (target != null) { Class<?> targetToUse = (target != null ? target : this.defaultTarget);
return generateSequencedClassName(target.getName().replace("$", "_") String featureNameToUse = this.featureNamePrefix + featureName;
+ SEPARATOR + StringUtils.capitalize(featureName)); return targetToUse.getName().replace("$", "_")
+ SEPARATOR + StringUtils.capitalize(featureNameToUse);
} }
return generateSequencedClassName(AOT_PACKAGE + featureName);
/**
* Return a new {@link ClassNameGenerator} instance for the specified
* feature name prefix, keeping track of all the class names generated
* by this instance.
* @param featureNamePrefix the feature name prefix to use
* @return a new instance for the specified feature name prefix
*/
ClassNameGenerator usingFeatureNamePrefix(String featureNamePrefix) {
return new ClassNameGenerator(this.defaultTarget, featureNamePrefix,
this.sequenceGenerator);
} }
private String clean(String name) { private String clean(String name) {

View File

@ -17,12 +17,15 @@
package org.springframework.aot.generate; package org.springframework.aot.generate;
import java.io.IOException; import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHints;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
* Default implementation of {@link GenerationContext}. * Default {@link GenerationContext} implementation.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Stephane Nicoll * @author Stephane Nicoll
@ -30,7 +33,7 @@ import org.springframework.util.Assert;
*/ */
public class DefaultGenerationContext implements GenerationContext { public class DefaultGenerationContext implements GenerationContext {
private final ClassNameGenerator classNameGenerator; private final Map<String, AtomicInteger> sequenceGenerator;
private final GeneratedClasses generatedClasses; private final GeneratedClasses generatedClasses;
@ -41,39 +44,45 @@ public class DefaultGenerationContext implements GenerationContext {
/** /**
* Create a new {@link DefaultGenerationContext} instance backed by the * Create a new {@link DefaultGenerationContext} instance backed by the
* specified {@code generatedFiles}. * specified {@link ClassNameGenerator} and {@link GeneratedFiles}.
* @param classNameGenerator the naming convention to use for generated
* class names
* @param generatedFiles the generated files * @param generatedFiles the generated files
*/ */
public DefaultGenerationContext(GeneratedFiles generatedFiles) { public DefaultGenerationContext(ClassNameGenerator classNameGenerator, GeneratedFiles generatedFiles) {
this(new ClassNameGenerator(), generatedFiles, new RuntimeHints()); this(new GeneratedClasses(classNameGenerator), generatedFiles, new RuntimeHints());
} }
/** /**
* Create a new {@link DefaultGenerationContext} instance backed by the * Create a new {@link DefaultGenerationContext} instance backed by the
* specified items. * specified items.
* @param classNameGenerator the class name generator * @param generatedClasses the generated classes
* @param generatedFiles the generated files * @param generatedFiles the generated files
* @param runtimeHints the runtime hints * @param runtimeHints the runtime hints
*/ */
public DefaultGenerationContext(ClassNameGenerator classNameGenerator, public DefaultGenerationContext(GeneratedClasses generatedClasses,
GeneratedFiles generatedFiles, RuntimeHints runtimeHints) { GeneratedFiles generatedFiles, RuntimeHints runtimeHints) {
Assert.notNull(classNameGenerator, "'classNameGenerator' must not be null"); Assert.notNull(generatedClasses, "'generatedClasses' must not be null");
Assert.notNull(generatedFiles, "'generatedFiles' must not be null"); Assert.notNull(generatedFiles, "'generatedFiles' must not be null");
Assert.notNull(runtimeHints, "'runtimeHints' must not be null"); Assert.notNull(runtimeHints, "'runtimeHints' must not be null");
this.classNameGenerator = classNameGenerator; this.sequenceGenerator = new ConcurrentHashMap<>();
this.generatedClasses = new GeneratedClasses(classNameGenerator); this.generatedClasses = generatedClasses;
this.generatedFiles = generatedFiles; this.generatedFiles = generatedFiles;
this.runtimeHints = runtimeHints; this.runtimeHints = runtimeHints;
} }
private DefaultGenerationContext(DefaultGenerationContext existing, String name) {
@Override int sequence = existing.sequenceGenerator
public ClassNameGenerator getClassNameGenerator() { .computeIfAbsent(name, key -> new AtomicInteger()).getAndIncrement();
return this.classNameGenerator; String nameToUse = (sequence > 0 ? name + sequence : name);
this.sequenceGenerator = existing.sequenceGenerator;
this.generatedClasses = existing.generatedClasses.withName(nameToUse);
this.generatedFiles = existing.generatedFiles;
this.runtimeHints = existing.runtimeHints;
} }
@Override @Override
public GeneratedClasses getClassGenerator() { public GeneratedClasses getGeneratedClasses() {
return this.generatedClasses; return this.generatedClasses;
} }
@ -87,6 +96,11 @@ public class DefaultGenerationContext implements GenerationContext {
return this.runtimeHints; return this.runtimeHints;
} }
@Override
public GenerationContext withName(String name) {
return new DefaultGenerationContext(this, name);
}
/** /**
* Write any generated content out to the generated files. * Write any generated content out to the generated files.
*/ */

View File

@ -16,22 +16,24 @@
package org.springframework.aot.generate; package org.springframework.aot.generate;
import org.springframework.aot.generate.ClassGenerator.JavaFileGenerator; import java.util.function.Consumer;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.JavaFile;
import org.springframework.util.Assert; import org.springframework.javapoet.TypeSpec;
import org.springframework.javapoet.TypeSpec.Builder;
/** /**
* A generated class. * A generated class is a container for generated methods.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Stephane Nicoll
* @since 6.0 * @since 6.0
* @see GeneratedClasses * @see GeneratedClasses
* @see ClassGenerator
*/ */
public final class GeneratedClass { public final class GeneratedClass {
private final JavaFileGenerator JavaFileGenerator; private final Consumer<Builder> typeSpecCustomizer;
private final ClassName name; private final ClassName name;
@ -44,12 +46,10 @@ public final class GeneratedClass {
* {@link GeneratedClasses}. * {@link GeneratedClasses}.
* @param name the generated name * @param name the generated name
*/ */
GeneratedClass(JavaFileGenerator javaFileGenerator, ClassName name) { GeneratedClass(Consumer<Builder> typeSpecCustomizer, ClassName name) {
MethodNameGenerator methodNameGenerator = new MethodNameGenerator( this.typeSpecCustomizer = typeSpecCustomizer;
javaFileGenerator.getReservedMethodNames());
this.JavaFileGenerator = javaFileGenerator;
this.name = name; this.name = name;
this.methods = new GeneratedMethods(methodNameGenerator); this.methods = new GeneratedMethods(new MethodNameGenerator());
} }
@ -70,15 +70,11 @@ public final class GeneratedClass {
} }
JavaFile generateJavaFile() { JavaFile generateJavaFile() {
JavaFile javaFile = this.JavaFileGenerator.generateJavaFile(this.name, TypeSpec.Builder typeSpecBuilder = TypeSpec.classBuilder(this.name);
this.methods); this.typeSpecCustomizer.accept(typeSpecBuilder);
Assert.state(this.name.packageName().equals(javaFile.packageName), this.methods.doWithMethodSpecs(typeSpecBuilder::addMethod);
() -> "Generated JavaFile should be in package '" return JavaFile.builder(this.name.packageName(), typeSpecBuilder.build())
+ this.name.packageName() + "'"); .build();
Assert.state(this.name.simpleName().equals(javaFile.typeSpec.name),
() -> "Generated JavaFile should be named '" + this.name.simpleName()
+ "'");
return javaFile;
} }
} }

View File

@ -22,59 +22,143 @@ import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.TypeSpec;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
* A managed collection of generated classes. * A managed collection of generated classes. This class is stateful so the
* same instance should be used for all class generation.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Stephane Nicoll
* @since 6.0 * @since 6.0
* @see GeneratedClass * @see GeneratedClass
*/ */
public class GeneratedClasses implements ClassGenerator { public class GeneratedClasses {
private final ClassNameGenerator classNameGenerator; private final ClassNameGenerator classNameGenerator;
private final Map<Owner, GeneratedClass> classes = new ConcurrentHashMap<>(); private final List<GeneratedClass> classes;
private final Map<Owner, GeneratedClass> classesByOwner;
/**
* Create a new instance using the specified naming conventions.
* @param classNameGenerator the class name generator to use
*/
public GeneratedClasses(ClassNameGenerator classNameGenerator) { public GeneratedClasses(ClassNameGenerator classNameGenerator) {
Assert.notNull(classNameGenerator, "'classNameGenerator' must not be null"); this(classNameGenerator, new ArrayList<>(), new ConcurrentHashMap<>());
this.classNameGenerator = classNameGenerator;
} }
private GeneratedClasses(ClassNameGenerator classNameGenerator,
@Override List<GeneratedClass> classes, Map<Owner, GeneratedClass> classesByOwner) {
public GeneratedClass getOrGenerateClass(JavaFileGenerator javaFileGenerator, Assert.notNull(classNameGenerator, "'classNameGenerator' must not be null");
Class<?> target, String featureName) { this.classNameGenerator = classNameGenerator;
this.classes = classes;
Assert.notNull(javaFileGenerator, "'javaFileGenerator' must not be null"); this.classesByOwner = classesByOwner;
Assert.notNull(target, "'target' must not be null");
Assert.hasLength(featureName, "'featureName' must not be empty");
Owner owner = new Owner(javaFileGenerator, target.getName(), featureName);
return this.classes.computeIfAbsent(owner,
key -> new GeneratedClass(javaFileGenerator,
this.classNameGenerator.generateClassName(target, featureName)));
} }
/** /**
* Write generated Spring {@code .factories} files to the given * Prepare a {@link GeneratedClass} for the specified {@code featureName}
* targeting the specified {@code component}.
* @param featureName the name of the feature to associate with the generated class
* @param component the target component
* @return a {@link Builder} for further configuration
*/
public Builder forFeatureComponent(String featureName, Class<?> component) {
Assert.hasLength(featureName, "'featureName' must not be empty");
Assert.notNull(component, "'component' must not be null");
return new Builder(featureName, component);
}
/**
* Prepare a {@link GeneratedClass} for the specified {@code featureName}
* and no particular component. This should be used for high-level code
* generation that are widely applicable and for entry points.
* @param featureName the name of the feature to associate with the generated class
* @return a {@link Builder} for further configuration
*/
public Builder forFeature(String featureName) {
Assert.hasLength(featureName, "'featureName' must not be empty");
return new Builder(featureName, null);
}
/**
* Write the {@link GeneratedClass generated classes} using the given
* {@link GeneratedFiles} instance. * {@link GeneratedFiles} instance.
* @param generatedFiles where to write the generated files * @param generatedFiles where to write the generated classes
* @throws IOException on IO error * @throws IOException on IO error
*/ */
public void writeTo(GeneratedFiles generatedFiles) throws IOException { public void writeTo(GeneratedFiles generatedFiles) throws IOException {
Assert.notNull(generatedFiles, "'generatedFiles' must not be null"); Assert.notNull(generatedFiles, "'generatedFiles' must not be null");
List<GeneratedClass> generatedClasses = new ArrayList<>(this.classes.values()); List<GeneratedClass> generatedClasses = new ArrayList<>(this.classes);
generatedClasses.sort(Comparator.comparing(GeneratedClass::getName)); generatedClasses.sort(Comparator.comparing(GeneratedClass::getName));
for (GeneratedClass generatedClass : generatedClasses) { for (GeneratedClass generatedClass : generatedClasses) {
generatedFiles.addSourceFile(generatedClass.generateJavaFile()); generatedFiles.addSourceFile(generatedClass.generateJavaFile());
} }
} }
private record Owner(JavaFileGenerator javaFileGenerator, String target, GeneratedClasses withName(String name) {
String featureName) { return new GeneratedClasses(this.classNameGenerator.usingFeatureNamePrefix(name),
this.classes, this.classesByOwner);
}
private record Owner(String id, String className) {
}
public class Builder {
private final String featureName;
@Nullable
private final Class<?> target;
Builder(String featureName, @Nullable Class<?> target) {
this.target = target;
this.featureName = featureName;
}
/**
* Generate a new {@link GeneratedClass} using the specified type
* customizer.
* @param typeSpecCustomizer a customizer for the {@link TypeSpec.Builder}
* @return a new {@link GeneratedClass}
*/
public GeneratedClass generate(Consumer<TypeSpec.Builder> typeSpecCustomizer) {
Assert.notNull(typeSpecCustomizer, "'typeSpecCustomizer' must not be null");
return createGeneratedClass(typeSpecCustomizer);
}
/**
* Get or generate a new {@link GeneratedClass} for the specified {@code id}.
* @param id a unique identifier
* @param typeSpecCustomizer a customizer for the {@link TypeSpec.Builder}
* @return a {@link GeneratedClass} instance
*/
public GeneratedClass getOrGenerate(String id,
Consumer<TypeSpec.Builder> typeSpecCustomizer) {
Assert.hasLength(id, "'id' must not be empty");
Assert.notNull(typeSpecCustomizer, "'typeSpecCustomizer' must not be null");
Owner owner = new Owner(id, GeneratedClasses.this.classNameGenerator
.getClassName(this.target, this.featureName));
return GeneratedClasses.this.classesByOwner.computeIfAbsent(owner,
key -> createGeneratedClass(typeSpecCustomizer));
}
private GeneratedClass createGeneratedClass(Consumer<TypeSpec.Builder> typeSpecCustomizer) {
ClassName className = GeneratedClasses.this.classNameGenerator
.generateClassName(this.target, this.featureName);
GeneratedClass generatedClass = new GeneratedClass(typeSpecCustomizer, className);
GeneratedClasses.this.classes.add(generatedClass);
return generatedClass;
}
} }

View File

@ -24,38 +24,31 @@ import org.springframework.aot.hint.SerializationHints;
/** /**
* Central interface used for code generation. * Central interface used for code generation.
* <p> *
* A generation context provides: * <p>A generation context provides:
* <ul> * <ul>
* <li>Support for {@link #getClassNameGenerator() class name generation}.</li> * <li>Management of all {@link #getGeneratedClasses()} generated classes},
* <li>Central management of all {@link #getGeneratedFiles() generated * including naming convention support.</li>
* files}.</li> * <li>Central management of all {@link #getGeneratedFiles() generated files}.</li>
* <li>Support for the recording of {@link #getRuntimeHints() runtime * <li>Support for the recording of {@link #getRuntimeHints() runtime hints}.</li>
* hints}.</li>
* </ul> * </ul>
* *
* <p>If a dedicated round of code generation is required while processing, it
* is possible to create a specialized context using {@link #withName(String)}.
*
* @author Phillip Webb * @author Phillip Webb
* @author Stephane Nicoll * @author Stephane Nicoll
* @since 6.0 * @since 6.0
*/ */
public interface GenerationContext { public interface GenerationContext {
/**
* Return the {@link ClassNameGenerator} being used by the context. Allows
* new class names to be generated before they are added to the
* {@link #getGeneratedFiles() generated files}.
* @return the class name generator
* @see #getGeneratedFiles()
*/
ClassNameGenerator getClassNameGenerator();
/** /**
* Return the {@link GeneratedClasses} being used by the context. Allows a * Return the {@link GeneratedClasses} being used by the context. Allows a
* single generated class to be shared across multiple AOT processors. All * single generated class to be shared across multiple AOT processors. All
* generated classes are written at the end of AOT processing. * generated classes are written at the end of AOT processing.
* @return the generated classes * @return the generated classes
*/ */
ClassGenerator getClassGenerator(); GeneratedClasses getGeneratedClasses();
/** /**
* Return the {@link GeneratedFiles} being used by the context. Used to * Return the {@link GeneratedFiles} being used by the context. Used to
@ -73,4 +66,14 @@ public interface GenerationContext {
*/ */
RuntimeHints getRuntimeHints(); RuntimeHints getRuntimeHints();
/**
* Return a new {@link GenerationContext} instance using the specified
* name to qualify generated assets for a dedicated round of code
* generation. If this name is already in use, a unique sequence is added
* to ensure the name is unique.
* @param name the name to use
* @return a specialized {@link GenerationContext} for the specified name
*/
GenerationContext withName(String name);
} }

View File

@ -32,12 +32,26 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
*/ */
class ClassNameGeneratorTests { class ClassNameGeneratorTests {
private final ClassNameGenerator generator = new ClassNameGenerator(); private final ClassNameGenerator generator = new ClassNameGenerator(Object.class);
@Test @Test
void generateClassNameWhenTargetClassIsNullUsesAotPackage() { void generateClassNameWhenTargetClassIsNullUsesMainTarget() {
ClassName generated = this.generator.generateClassName((Class<?>) null, "test"); ClassName generated = this.generator.generateClassName(null, "test");
assertThat(generated).hasToString("__.Test"); assertThat(generated).hasToString("java.lang.Object__Test");
}
@Test
void generateClassNameUseFeatureNamePrefix() {
ClassName generated = new ClassNameGenerator(Object.class, "One")
.generateClassName(InputStream.class, "test");
assertThat(generated).hasToString("java.io.InputStream__OneTest");
}
@Test
void generateClassNameWithNoTextFeatureNamePrefix() {
ClassName generated = new ClassNameGenerator(Object.class, " ")
.generateClassName(InputStream.class, "test");
assertThat(generated).hasToString("java.io.InputStream__Test");
} }
@Test @Test
@ -59,8 +73,7 @@ class ClassNameGeneratorTests {
@Test @Test
void generateClassNameWithClassWhenLowercaseFeatureNameGeneratesName() { void generateClassNameWithClassWhenLowercaseFeatureNameGeneratesName() {
ClassName generated = this.generator.generateClassName(InputStream.class, ClassName generated = this.generator.generateClassName(InputStream.class, "bytes");
"bytes");
assertThat(generated).hasToString("java.io.InputStream__Bytes"); assertThat(generated).hasToString("java.io.InputStream__Bytes");
} }
@ -81,6 +94,15 @@ class ClassNameGeneratorTests {
assertThat(generated3).hasToString("java.io.InputStream__Bytes2"); assertThat(generated3).hasToString("java.io.InputStream__Bytes2");
} }
@Test
void getClassNameWhenMultipleCallsReturnsSameName() {
String name1 = this.generator.getClassName(InputStream.class, "bytes");
String name2 = this.generator.getClassName(InputStream.class, "bytes");
String name3 = this.generator.getClassName(InputStream.class, "bytes");
assertThat(name1).hasToString("java.io.InputStream__Bytes")
.isEqualTo(name2).isEqualTo(name3);
}
static class TestBean { static class TestBean {
} }

View File

@ -16,9 +16,14 @@
package org.springframework.aot.generate; package org.springframework.aot.generate;
import java.util.function.Consumer;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.GeneratedFiles.Kind;
import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHints;
import org.springframework.core.testfixture.aot.generate.TestTarget;
import org.springframework.javapoet.TypeSpec.Builder;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@ -31,9 +36,12 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
*/ */
class DefaultGenerationContextTests { class DefaultGenerationContextTests {
private final ClassNameGenerator classNameGenerator = new ClassNameGenerator(); private static final Consumer<Builder> typeSpecCustomizer = type -> {};
private final GeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); private final GeneratedClasses generatedClasses = new GeneratedClasses(
new ClassNameGenerator(TestTarget.class));
private final InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles();
private final RuntimeHints runtimeHints = new RuntimeHints(); private final RuntimeHints runtimeHints = new RuntimeHints();
@ -41,9 +49,7 @@ class DefaultGenerationContextTests {
@Test @Test
void createWithOnlyGeneratedFilesCreatesContext() { void createWithOnlyGeneratedFilesCreatesContext() {
DefaultGenerationContext context = new DefaultGenerationContext( DefaultGenerationContext context = new DefaultGenerationContext(
this.generatedFiles); new ClassNameGenerator(TestTarget.class), this.generatedFiles);
assertThat(context.getClassNameGenerator())
.isInstanceOf(ClassNameGenerator.class);
assertThat(context.getGeneratedFiles()).isSameAs(this.generatedFiles); assertThat(context.getGeneratedFiles()).isSameAs(this.generatedFiles);
assertThat(context.getRuntimeHints()).isInstanceOf(RuntimeHints.class); assertThat(context.getRuntimeHints()).isInstanceOf(RuntimeHints.class);
} }
@ -51,24 +57,23 @@ class DefaultGenerationContextTests {
@Test @Test
void createCreatesContext() { void createCreatesContext() {
DefaultGenerationContext context = new DefaultGenerationContext( DefaultGenerationContext context = new DefaultGenerationContext(
this.classNameGenerator, this.generatedFiles, this.runtimeHints); this.generatedClasses, this.generatedFiles, this.runtimeHints);
assertThat(context.getClassNameGenerator()).isNotNull();
assertThat(context.getGeneratedFiles()).isNotNull(); assertThat(context.getGeneratedFiles()).isNotNull();
assertThat(context.getRuntimeHints()).isNotNull(); assertThat(context.getRuntimeHints()).isNotNull();
} }
@Test @Test
void createWhenClassNameGeneratorIsNullThrowsException() { void createWhenGeneratedClassesIsNullThrowsException() {
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> new DefaultGenerationContext(null, this.generatedFiles, .isThrownBy(() -> new DefaultGenerationContext(null, this.generatedFiles,
this.runtimeHints)) this.runtimeHints))
.withMessage("'classNameGenerator' must not be null"); .withMessage("'generatedClasses' must not be null");
} }
@Test @Test
void createWhenGeneratedFilesIsNullThrowsException() { void createWhenGeneratedFilesIsNullThrowsException() {
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> new DefaultGenerationContext(this.classNameGenerator, .isThrownBy(() -> new DefaultGenerationContext(this.generatedClasses,
null, this.runtimeHints)) null, this.runtimeHints))
.withMessage("'generatedFiles' must not be null"); .withMessage("'generatedFiles' must not be null");
} }
@ -76,30 +81,71 @@ class DefaultGenerationContextTests {
@Test @Test
void createWhenRuntimeHintsIsNullThrowsException() { void createWhenRuntimeHintsIsNullThrowsException() {
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> new DefaultGenerationContext(this.classNameGenerator, .isThrownBy(() -> new DefaultGenerationContext(this.generatedClasses,
this.generatedFiles, null)) this.generatedFiles, null))
.withMessage("'runtimeHints' must not be null"); .withMessage("'runtimeHints' must not be null");
} }
@Test @Test
void getClassNameGeneratorReturnsClassNameGenerator() { void getGeneratedClassesReturnsClassNameGenerator() {
DefaultGenerationContext context = new DefaultGenerationContext( DefaultGenerationContext context = new DefaultGenerationContext(
this.classNameGenerator, this.generatedFiles, this.runtimeHints); this.generatedClasses, this.generatedFiles, this.runtimeHints);
assertThat(context.getClassNameGenerator()).isSameAs(this.classNameGenerator); assertThat(context.getGeneratedClasses()).isSameAs(this.generatedClasses);
} }
@Test @Test
void getGeneratedFilesReturnsGeneratedFiles() { void getGeneratedFilesReturnsGeneratedFiles() {
DefaultGenerationContext context = new DefaultGenerationContext( DefaultGenerationContext context = new DefaultGenerationContext(
this.classNameGenerator, this.generatedFiles, this.runtimeHints); this.generatedClasses, this.generatedFiles, this.runtimeHints);
assertThat(context.getGeneratedFiles()).isSameAs(this.generatedFiles); assertThat(context.getGeneratedFiles()).isSameAs(this.generatedFiles);
} }
@Test @Test
void getRuntimeHintsReturnsRuntimeHints() { void getRuntimeHintsReturnsRuntimeHints() {
DefaultGenerationContext context = new DefaultGenerationContext( DefaultGenerationContext context = new DefaultGenerationContext(
this.classNameGenerator, this.generatedFiles, this.runtimeHints); this.generatedClasses, this.generatedFiles, this.runtimeHints);
assertThat(context.getRuntimeHints()).isSameAs(this.runtimeHints); assertThat(context.getRuntimeHints()).isSameAs(this.runtimeHints);
} }
@Test
void withNameUpdateNamingConvention() {
DefaultGenerationContext context = new DefaultGenerationContext(
new ClassNameGenerator(TestTarget.class), this.generatedFiles);
GenerationContext anotherContext = context.withName("Another");
GeneratedClass generatedClass = anotherContext.getGeneratedClasses()
.forFeature("Test").generate(typeSpecCustomizer);
assertThat(generatedClass.getName().simpleName()).endsWith("__AnotherTest");
}
@Test
void withNameKeepTrackOfAllGeneratedFiles() {
DefaultGenerationContext context = new DefaultGenerationContext(
new ClassNameGenerator(TestTarget.class), this.generatedFiles);
context.getGeneratedClasses().forFeature("Test").generate(typeSpecCustomizer);
GenerationContext anotherContext = context.withName("Another");
assertThat(anotherContext.getGeneratedClasses()).isNotSameAs(context.getGeneratedClasses());
assertThat(anotherContext.getGeneratedFiles()).isSameAs(context.getGeneratedFiles());
assertThat(anotherContext.getRuntimeHints()).isSameAs(context.getRuntimeHints());
anotherContext.getGeneratedClasses().forFeature("Test").generate(typeSpecCustomizer);
context.writeGeneratedContent();
assertThat(this.generatedFiles.getGeneratedFiles(Kind.SOURCE)).hasSize(2);
}
@Test
void withNameGenerateUniqueName() {
DefaultGenerationContext context = new DefaultGenerationContext(
new ClassNameGenerator(Object.class), this.generatedFiles);
context.withName("Test").getGeneratedClasses()
.forFeature("Feature").generate(typeSpecCustomizer);
context.withName("Test").getGeneratedClasses()
.forFeature("Feature").generate(typeSpecCustomizer);
context.withName("Test").getGeneratedClasses()
.forFeature("Feature").generate(typeSpecCustomizer);
context.writeGeneratedContent();
assertThat(this.generatedFiles.getGeneratedFiles(Kind.SOURCE)).containsOnlyKeys(
"java/lang/Object__TestFeature.java",
"java/lang/Object__Test1Feature.java",
"java/lang/Object__Test2Feature.java");
}
} }

View File

@ -16,77 +16,43 @@
package org.springframework.aot.generate; package org.springframework.aot.generate;
import java.util.function.Consumer;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.TypeSpec.Builder;
import org.springframework.javapoet.TypeSpec;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
/** /**
* Tests for {@link GeneratedClass}. * Tests for {@link GeneratedClass}.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Stephane Nicoll
*/ */
class GeneratedClassTests { class GeneratedClassTests {
@Test @Test
void getNameReturnsName() { void getNameReturnsName() {
ClassName name = ClassName.bestGuess("com.example.Test"); ClassName name = ClassName.bestGuess("com.example.Test");
GeneratedClass generatedClass = new GeneratedClass(this::generateJavaFile, name); GeneratedClass generatedClass = new GeneratedClass(emptyTypeSpec(), name);
assertThat(generatedClass.getName()).isSameAs(name); assertThat(generatedClass.getName()).isSameAs(name);
} }
@Test @Test
void generateJavaFileSuppliesGeneratedMethods() { void generateJavaFileIncludesGeneratedMethods() {
ClassName name = ClassName.bestGuess("com.example.Test"); ClassName name = ClassName.bestGuess("com.example.Test");
GeneratedClass generatedClass = new GeneratedClass(this::generateJavaFile, name); GeneratedClass generatedClass = new GeneratedClass(emptyTypeSpec(), name);
MethodGenerator methodGenerator = generatedClass.getMethodGenerator(); MethodGenerator methodGenerator = generatedClass.getMethodGenerator();
methodGenerator.generateMethod("test") methodGenerator.generateMethod("test")
.using(builder -> builder.addJavadoc("Test Method")); .using(builder -> builder.addJavadoc("Test Method"));
assertThat(generatedClass.generateJavaFile().toString()).contains("Test Method"); assertThat(generatedClass.generateJavaFile().toString()).contains("Test Method");
} }
@Test
void generateJavaFileWhenHasBadPackageThrowsException() {
ClassName name = ClassName.bestGuess("com.example.Test");
GeneratedClass generatedClass = new GeneratedClass(
this::generateBadPackageJavaFile, name);
assertThatIllegalStateException()
.isThrownBy(
() -> assertThat(generatedClass.generateJavaFile().toString()))
.withMessageContaining("should be in package");
}
@Test private Consumer<Builder> emptyTypeSpec() {
void generateJavaFileWhenHasBadNameThrowsException() { return type -> {};
ClassName name = ClassName.bestGuess("com.example.Test");
GeneratedClass generatedClass = new GeneratedClass(this::generateBadNameJavaFile,
name);
assertThatIllegalStateException()
.isThrownBy(
() -> assertThat(generatedClass.generateJavaFile().toString()))
.withMessageContaining("should be named");
}
private JavaFile generateJavaFile(ClassName className, GeneratedMethods methods) {
TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className);
methods.doWithMethodSpecs(classBuilder::addMethod);
return JavaFile.builder(className.packageName(), classBuilder.build()).build();
}
private JavaFile generateBadPackageJavaFile(ClassName className,
GeneratedMethods methods) {
TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className);
return JavaFile.builder("naughty", classBuilder.build()).build();
}
private JavaFile generateBadNameJavaFile(ClassName className,
GeneratedMethods methods) {
TypeSpec.Builder classBuilder = TypeSpec.classBuilder("Naughty");
return JavaFile.builder(className.packageName(), classBuilder.build()).build();
} }
} }

View File

@ -16,27 +16,34 @@
package org.springframework.aot.generate; package org.springframework.aot.generate;
import java.io.IOException;
import java.util.function.Consumer;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.ClassGenerator.JavaFileGenerator; import org.springframework.aot.generate.GeneratedFiles.Kind;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.TypeSpec; import org.springframework.javapoet.TypeSpec;
import org.springframework.javapoet.TypeSpec.Builder;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
/** /**
* Tests for {@link GeneratedClasses}. * Tests for {@link GeneratedClasses}.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Stephane Nicoll
*/ */
class GeneratedClassesTests { class GeneratedClassesTests {
private GeneratedClasses generatedClasses = new GeneratedClasses( private static final Consumer<TypeSpec.Builder> emptyTypeCustomizer = type -> {};
new ClassNameGenerator());
private static final JavaFileGenerator JAVA_FILE_GENERATOR = GeneratedClassesTests::generateJavaFile; private final GeneratedClasses generatedClasses = new GeneratedClasses(
new ClassNameGenerator(Object.class));
@Test @Test
void createWhenClassNameGeneratorIsNullThrowsException() { void createWhenClassNameGeneratorIsNullThrowsException() {
@ -45,61 +52,118 @@ class GeneratedClassesTests {
} }
@Test @Test
void getOrGenerateWithClassTargetWhenJavaFileGeneratorIsNullThrowsException() { void forFeatureComponentWhenTargetIsNullThrowsException() {
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.generatedClasses.getOrGenerateClass(null, .isThrownBy(() -> this.generatedClasses.forFeatureComponent("test", null))
TestTarget.class, "test")) .withMessage("'component' must not be null");
.withMessage("'javaFileGenerator' must not be null");
} }
@Test @Test
void getOrGenerateWithClassTargetWhenTargetIsNullThrowsException() { void forFeatureComponentWhenFeatureNameIsEmptyThrowsException() {
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.generatedClasses .isThrownBy(() -> this.generatedClasses.forFeatureComponent("", TestComponent.class))
.getOrGenerateClass(JAVA_FILE_GENERATOR, (Class<?>) null, "test"))
.withMessage("'target' must not be null");
}
@Test
void getOrGenerateWithClassTargetWhenFeatureIsNullThrowsException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.generatedClasses
.getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, null))
.withMessage("'featureName' must not be empty"); .withMessage("'featureName' must not be empty");
} }
@Test @Test
void getOrGenerateWhenNewReturnsGeneratedMethod() { void forFeatureWhenFeatureNameIsEmptyThrowsException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.generatedClasses.forFeature(""))
.withMessage("'featureName' must not be empty");
}
@Test
void generateWhenTypeSpecCustomizerIsNullThrowsException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.generatedClasses
.forFeatureComponent("test", TestComponent.class).generate(null))
.withMessage("'typeSpecCustomizer' must not be null");
}
@Test
void forFeatureUsesDefaultTarget() {
GeneratedClass generatedClass = this.generatedClasses
.forFeature("Test").generate(emptyTypeCustomizer);
assertThat(generatedClass.getName()).hasToString("java.lang.Object__Test");
}
@Test
void forFeatureComponentUsesComponent() {
GeneratedClass generatedClass = this.generatedClasses
.forFeatureComponent("Test", TestComponent.class).generate(emptyTypeCustomizer);
assertThat(generatedClass.getName().toString()).endsWith("TestComponent__Test");
}
@Test
void generateReturnsDifferentInstances() {
Consumer<Builder> typeCustomizer = mockTypeCustomizer();
GeneratedClass generatedClass1 = this.generatedClasses GeneratedClass generatedClass1 = this.generatedClasses
.getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, "one"); .forFeatureComponent("one", TestComponent.class).generate(typeCustomizer);
GeneratedClass generatedClass2 = this.generatedClasses GeneratedClass generatedClass2 = this.generatedClasses
.getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, "two"); .forFeatureComponent("one", TestComponent.class).generate(typeCustomizer);
assertThat(generatedClass1).isNotSameAs(generatedClass2);
assertThat(generatedClass1.getName().simpleName()).endsWith("__One");
assertThat(generatedClass2.getName().simpleName()).endsWith("__One1");
}
@Test
void getOrGenerateWhenNewReturnsGeneratedMethod() {
Consumer<Builder> typeCustomizer = mockTypeCustomizer();
GeneratedClass generatedClass1 = this.generatedClasses
.forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer);
GeneratedClass generatedClass2 = this.generatedClasses
.forFeatureComponent("two", TestComponent.class).getOrGenerate("facet", typeCustomizer);
assertThat(generatedClass1).isNotNull().isNotEqualTo(generatedClass2); assertThat(generatedClass1).isNotNull().isNotEqualTo(generatedClass2);
assertThat(generatedClass2).isNotNull(); assertThat(generatedClass2).isNotNull();
} }
@Test @Test
void getOrGenerateWhenRepeatReturnsSameGeneratedMethod() { void getOrGenerateWhenRepeatReturnsSameGeneratedMethod() {
GeneratedClasses generated = this.generatedClasses; Consumer<Builder> typeCustomizer = mockTypeCustomizer();
GeneratedClass generatedClass1 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, GeneratedClass generatedClass1 = this.generatedClasses
TestTarget.class, "one"); .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer);
GeneratedClass generatedClass2 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, GeneratedClass generatedClass2 = this.generatedClasses
TestTarget.class, "one"); .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer);
GeneratedClass generatedClass3 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, GeneratedClass generatedClass3 = this.generatedClasses
TestTarget.class, "one"); .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer);
GeneratedClass generatedClass4 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR,
TestTarget.class, "two");
assertThat(generatedClass1).isNotNull().isSameAs(generatedClass2) assertThat(generatedClass1).isNotNull().isSameAs(generatedClass2)
.isSameAs(generatedClass3).isNotSameAs(generatedClass4); .isSameAs(generatedClass3);
verifyNoInteractions(typeCustomizer);
generatedClass1.generateJavaFile();
verify(typeCustomizer).accept(any());
} }
static JavaFile generateJavaFile(ClassName className, @Test
GeneratedMethods generatedMethods) { @SuppressWarnings("unchecked")
TypeSpec typeSpec = TypeSpec.classBuilder(className).addJavadoc("Test").build(); void writeToInvokeTypeSpecCustomizer() throws IOException {
return JavaFile.builder(className.packageName(), typeSpec).build(); Consumer<TypeSpec.Builder> typeSpecCustomizer = mock(Consumer.class);
this.generatedClasses.forFeatureComponent("one", TestComponent.class)
.generate(typeSpecCustomizer);
verifyNoInteractions(typeSpecCustomizer);
InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles();
this.generatedClasses.writeTo(generatedFiles);
verify(typeSpecCustomizer).accept(any());
assertThat(generatedFiles.getGeneratedFiles(Kind.SOURCE)).hasSize(1);
} }
private static class TestTarget { @Test
void withNameUpdatesNamingConventions() {
GeneratedClass generatedClass1 = this.generatedClasses
.forFeatureComponent("one", TestComponent.class).generate(emptyTypeCustomizer);
GeneratedClass generatedClass2 = this.generatedClasses.withName("Another")
.forFeatureComponent("one", TestComponent.class).generate(emptyTypeCustomizer);
assertThat(generatedClass1.getName().toString()).endsWith("TestComponent__One");
assertThat(generatedClass2.getName().toString()).endsWith("TestComponent__AnotherOne");
}
@SuppressWarnings("unchecked")
private Consumer<TypeSpec.Builder> mockTypeCustomizer() {
return mock(Consumer.class);
}
private static class TestComponent {
} }

View File

@ -0,0 +1,40 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.core.testfixture.aot.generate;
import org.springframework.aot.generate.ClassNameGenerator;
import org.springframework.aot.generate.DefaultGenerationContext;
import org.springframework.aot.generate.GeneratedFiles;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.InMemoryGeneratedFiles;
/**
* Test {@link GenerationContext} implementation that uses
* {@link TestTarget} as the main target.
*
* @author Stephane Nicoll
*/
public class TestGenerationContext extends DefaultGenerationContext {
public TestGenerationContext(GeneratedFiles generatedFiles) {
super(new ClassNameGenerator(TestTarget.class), generatedFiles);
}
public TestGenerationContext() {
this(new InMemoryGeneratedFiles());
}
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.core.testfixture.aot.generate;
/**
* A target used by tests of code generation.
*
* @author Stephane Nicoll
*/
public class TestTarget {
}

View File

@ -30,6 +30,7 @@ import java.util.Map;
import java.util.Properties; import java.util.Properties;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import jakarta.persistence.EntityManager; import jakarta.persistence.EntityManager;
import jakarta.persistence.EntityManagerFactory; import jakarta.persistence.EntityManagerFactory;
@ -39,11 +40,10 @@ import jakarta.persistence.PersistenceProperty;
import jakarta.persistence.PersistenceUnit; import jakarta.persistence.PersistenceUnit;
import jakarta.persistence.SynchronizationType; import jakarta.persistence.SynchronizationType;
import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodGenerator;
import org.springframework.aot.generate.MethodNameGenerator;
import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
@ -70,11 +70,9 @@ import org.springframework.core.BridgeMethodResolver;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.core.PriorityOrdered; import org.springframework.core.PriorityOrdered;
import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.TypeSpec; import org.springframework.javapoet.MethodSpec.Builder;
import org.springframework.jndi.JndiLocatorDelegate; import org.springframework.jndi.JndiLocatorDelegate;
import org.springframework.jndi.JndiTemplate; import org.springframework.jndi.JndiTemplate;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
@ -789,34 +787,27 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar
@Override @Override
public void applyTo(GenerationContext generationContext, public void applyTo(GenerationContext generationContext,
BeanRegistrationCode beanRegistrationCode) { BeanRegistrationCode beanRegistrationCode) {
ClassName className = generationContext.getClassNameGenerator() GeneratedClass generatedClass = generationContext.getGeneratedClasses()
.generateClassName(this.target, "PersistenceInjection"); .forFeatureComponent("PersistenceInjection", this.target).generate(type -> {
TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); type.addJavadoc("Persistence injection for {@link $T}.", this.target);
classBuilder.addJavadoc("Persistence injection for {@link $T}.", this.target); type.addModifiers(javax.lang.model.element.Modifier.PUBLIC);
classBuilder.addModifiers(javax.lang.model.element.Modifier.PUBLIC); });
GeneratedMethods methods = new GeneratedMethods( generatedClass.getMethodGenerator().generateMethod(APPLY_METHOD)
new MethodNameGenerator(APPLY_METHOD)); .using(generateMethod(generationContext.getRuntimeHints(), generatedClass.getMethodGenerator()));
classBuilder.addMethod(generateMethod(generationContext.getRuntimeHints(),
className, methods));
methods.doWithMethodSpecs(classBuilder::addMethod);
JavaFile javaFile = JavaFile
.builder(className.packageName(), classBuilder.build()).build();
generationContext.getGeneratedFiles().addSourceFile(javaFile);
beanRegistrationCode.addInstancePostProcessor( beanRegistrationCode.addInstancePostProcessor(
MethodReference.ofStatic(className, APPLY_METHOD)); MethodReference.ofStatic(generatedClass.getName(), APPLY_METHOD));
} }
private MethodSpec generateMethod(RuntimeHints hints, ClassName className, private Consumer<Builder> generateMethod(RuntimeHints hints, MethodGenerator methodGenerator) {
MethodGenerator methodGenerator) { return method -> {
MethodSpec.Builder builder = MethodSpec.methodBuilder(APPLY_METHOD); method.addJavadoc("Apply the persistence injection.");
builder.addJavadoc("Apply the persistence injection."); method.addModifiers(javax.lang.model.element.Modifier.PUBLIC,
builder.addModifiers(javax.lang.model.element.Modifier.PUBLIC,
javax.lang.model.element.Modifier.STATIC); javax.lang.model.element.Modifier.STATIC);
builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER);
builder.addParameter(this.target, INSTANCE_PARAMETER); method.addParameter(this.target, INSTANCE_PARAMETER);
builder.returns(this.target); method.returns(this.target);
builder.addCode(generateMethodCode(hints, methodGenerator)); method.addCode(generateMethodCode(hints, methodGenerator));
return builder.build(); };
} }
private CodeBlock generateMethodCode(RuntimeHints hints, private CodeBlock generateMethodCode(RuntimeHints hints,

View File

@ -43,6 +43,7 @@ import org.springframework.beans.factory.aot.BeanRegistrationCode;
import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -67,7 +68,7 @@ class PersistenceAnnotationBeanPostProcessorAotContributionTests {
void setup() { void setup() {
this.beanFactory = new DefaultListableBeanFactory(); this.beanFactory = new DefaultListableBeanFactory();
this.generatedFiles = new InMemoryGeneratedFiles(); this.generatedFiles = new InMemoryGeneratedFiles();
this.generationContext = new DefaultGenerationContext(generatedFiles); this.generationContext = new TestGenerationContext(generatedFiles);
} }
@Test @Test
@ -183,6 +184,7 @@ class PersistenceAnnotationBeanPostProcessorAotContributionTests {
.processAheadOfTime(registeredBean); .processAheadOfTime(registeredBean);
BeanRegistrationCode beanRegistrationCode = mock(BeanRegistrationCode.class); BeanRegistrationCode beanRegistrationCode = mock(BeanRegistrationCode.class);
contribution.applyTo(generationContext, beanRegistrationCode); contribution.applyTo(generationContext, beanRegistrationCode);
generationContext.writeGeneratedContent();
TestCompiler.forSystem().withFiles(generatedFiles) TestCompiler.forSystem().withFiles(generatedFiles)
.compile(compiled -> result.accept(new Invoker(compiled), compiled)); .compile(compiled -> result.accept(new Invoker(compiled), compiled));
} }