From 6199835d6ea791a7f74a8e98f1a2a08c17014aef Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Wed, 22 Jun 2022 14:20:00 +0200 Subject: [PATCH] Harmonize generated class name conventions This commit moves the responsibility of naming classes to the GenerationContext. This was already largely the case before, except that the concept of a "mainTarget" and "featureNamePrefix" was specific to bean factory initialization contributors. ClassNameGenerator should now be instantiated with a default target and an optional feature name prefix. As a result, it does no longer generate class names in the "__" package. GeneratedClasses can now provide a new, unique, GeneratedClass or offer a container for retrieving the same GeneratedClass based on an identifier. This lets all contributors use this facility rather than creating JavaFile manually. This also means that ClassNameGenerator is no longer exposed. Because the naming conventions are now part of the GenerationContext, it is required to be able to retrieve a specialized version of it if a code generation round needs to use different naming conventions. A new withName method has been added to that effect. Closes gh-28585 --- ...roxyBeanRegistrationAotProcessorTests.java | 3 +- .../AutowiredAnnotationBeanPostProcessor.java | 44 +++--- .../aot/BeanDefinitionMethodGenerator.java | 63 ++------ .../aot/BeanFactoryInitializationCode.java | 19 --- .../aot/BeanRegistrationsAotContribution.java | 45 ++---- .../DefaultBeanRegistrationCodeFragments.java | 8 +- ...nBeanRegistrationAotContributionTests.java | 7 +- .../BeanDefinitionMethodGeneratorTests.java | 23 +-- ...BeanRegistrationsAotContributionTests.java | 13 +- .../InstanceSupplierCodeGeneratorTests.java | 3 +- .../MockBeanFactoryInitializationCode.java | 15 -- .../aot/ApplicationContextAotGenerator.java | 41 ++--- ...ionContextInitializationCodeGenerator.java | 49 ++---- ...lassPostProcessorAotContributionTests.java | 3 +- .../ApplicationContextAotGeneratorTests.java | 10 +- ...ssorBeanRegistrationAotProcessorTests.java | 6 +- ...actoryInitializationAotProcessorTests.java | 20 +-- .../aot/generate/ClassGenerator.java | 75 --------- .../aot/generate/ClassNameGenerator.java | 99 +++++++++--- .../generate/DefaultGenerationContext.java | 44 ++++-- .../aot/generate/GeneratedClass.java | 34 ++--- .../aot/generate/GeneratedClasses.java | 128 +++++++++++++--- .../aot/generate/GenerationContext.java | 37 ++--- .../aot/generate/ClassNameGeneratorTests.java | 36 ++++- .../DefaultGenerationContextTests.java | 78 ++++++++-- .../aot/generate/GeneratedClassTests.java | 52 ++----- .../aot/generate/GeneratedClassesTests.java | 142 +++++++++++++----- .../aot/generate/TestGenerationContext.java | 40 +++++ .../testfixture/aot/generate/TestTarget.java | 25 +++ ...ersistenceAnnotationBeanPostProcessor.java | 51 +++---- ...BeanPostProcessorAotContributionTests.java | 4 +- 31 files changed, 652 insertions(+), 565 deletions(-) delete mode 100644 spring-core/src/main/java/org/springframework/aot/generate/ClassGenerator.java create mode 100644 spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestGenerationContext.java create mode 100644 spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestTarget.java diff --git a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java index 634d0a097a..d9897fa588 100644 --- a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java +++ b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java @@ -42,6 +42,7 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryInitializationCode; import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolder; import org.springframework.core.ResolvableType; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -71,7 +72,7 @@ class ScopedProxyBeanRegistrationAotProcessorTests { this.beanFactory = new DefaultListableBeanFactory(); this.processor = new TestBeanRegistrationsAotProcessor(); this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.generationContext = new TestGenerationContext(this.generatedFiles); this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java index dc4db39242..c75d94aa32 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java @@ -41,6 +41,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.aot.generate.AccessVisibility; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.ExecutableHint; @@ -79,11 +80,8 @@ import org.springframework.core.annotation.AnnotationAttributes; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.core.annotation.MergedAnnotations; -import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; -import org.springframework.javapoet.TypeSpec; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -910,30 +908,28 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA @Override public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { - - ClassName className = generationContext.getClassNameGenerator() - .generateClassName(this.target, "Autowiring"); - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - classBuilder.addJavadoc("Autowiring for {@link $T}.", this.target); - classBuilder.addModifiers(javax.lang.model.element.Modifier.PUBLIC); - classBuilder.addMethod(generateMethod(generationContext.getRuntimeHints())); - JavaFile javaFile = JavaFile - .builder(className.packageName(), classBuilder.build()).build(); - generationContext.getGeneratedFiles().addSourceFile(javaFile); + GeneratedClass generatedClass = generationContext.getGeneratedClasses() + .forFeatureComponent("Autowiring", this.target) + .generate(type -> { + type.addJavadoc("Autowiring for {@link $T}.", this.target); + type.addModifiers(javax.lang.model.element.Modifier.PUBLIC); + }); + generatedClass.getMethodGenerator().generateMethod(APPLY_METHOD) + .using(generateMethod(generationContext.getRuntimeHints())); beanRegistrationCode.addInstancePostProcessor( - MethodReference.ofStatic(className, APPLY_METHOD)); + MethodReference.ofStatic(generatedClass.getName(), APPLY_METHOD)); } - private MethodSpec generateMethod(RuntimeHints hints) { - MethodSpec.Builder builder = MethodSpec.methodBuilder(APPLY_METHOD); - builder.addJavadoc("Apply the autowiring."); - builder.addModifiers(javax.lang.model.element.Modifier.PUBLIC, - javax.lang.model.element.Modifier.STATIC); - builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); - builder.addParameter(this.target, INSTANCE_PARAMETER); - builder.returns(this.target); - builder.addCode(generateMethodCode(hints)); - return builder.build(); + private Consumer generateMethod(RuntimeHints hints) { + return method -> { + method.addJavadoc("Apply the autowiring."); + method.addModifiers(javax.lang.model.element.Modifier.PUBLIC, + javax.lang.model.element.Modifier.STATIC); + method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); + method.addParameter(this.target, INSTANCE_PARAMETER); + method.returns(this.target); + method.addCode(generateMethodCode(hints)); + }; } private CodeBlock generateMethodCode(RuntimeHints hints) { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java index de604409b1..ed7682a882 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -21,10 +21,8 @@ import java.util.List; import javax.lang.model.element.Modifier; -import org.springframework.aot.generate.ClassGenerator.JavaFileGenerator; import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; -import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodNameGenerator; @@ -32,8 +30,6 @@ import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.JavaFile; -import org.springframework.javapoet.TypeSpec; import org.springframework.lang.Nullable; /** @@ -45,6 +41,8 @@ import org.springframework.lang.Nullable; */ class BeanDefinitionMethodGenerator { + private static final String FEATURE_NAME = "BeanDefinitions"; + private final BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; private final RegisteredBean registeredBean; @@ -81,22 +79,23 @@ class BeanDefinitionMethodGenerator { * Generate the method that returns the {@link BeanDefinition} to be * registered. * @param generationContext the generation context - * @param featureNamePrefix the prefix to use for the feature name * @param beanRegistrationsCode the bean registrations code * @return a reference to the generated method. */ MethodReference generateBeanDefinitionMethod(GenerationContext generationContext, - String featureNamePrefix, BeanRegistrationsCode beanRegistrationsCode) { + BeanRegistrationsCode beanRegistrationsCode) { BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext, - beanRegistrationsCode, featureNamePrefix); + beanRegistrationsCode); Class target = codeFragments.getTarget(this.registeredBean, this.constructorOrFactoryMethod); if (!target.getName().startsWith("java.")) { - String featureName = featureNamePrefix + "BeanDefinitions"; - GeneratedClass generatedClass = generationContext.getClassGenerator() - .getOrGenerateClass(new BeanDefinitionsJavaFileGenerator(target), - target, featureName); + GeneratedClass generatedClass = generationContext.getGeneratedClasses() + .forFeatureComponent(FEATURE_NAME, target) + .getOrGenerate(FEATURE_NAME, type -> { + type.addJavadoc("Bean definitions for {@link $T}", target); + type.addModifiers(Modifier.PUBLIC); + }); MethodGenerator methodGenerator = generatedClass.getMethodGenerator() .withName(getName()); GeneratedMethod generatedMethod = generateBeanDefinitionMethod( @@ -115,11 +114,10 @@ class BeanDefinitionMethodGenerator { } private BeanRegistrationCodeFragments getCodeFragments(GenerationContext generationContext, - BeanRegistrationsCode beanRegistrationsCode, String featureNamePrefix) { + BeanRegistrationsCode beanRegistrationsCode) { BeanRegistrationCodeFragments codeFragments = new DefaultBeanRegistrationCodeFragments( - beanRegistrationsCode, this.registeredBean, this.methodGeneratorFactory, - featureNamePrefix); + beanRegistrationsCode, this.registeredBean, this.methodGeneratorFactory); for (BeanRegistrationAotContribution aotContribution : this.aotContributions) { codeFragments = aotContribution.customizeBeanRegistrationCodeFragments(generationContext, codeFragments); } @@ -172,41 +170,4 @@ class BeanDefinitionMethodGenerator { return beanName; } - - /** - * {@link BeanDefinitionsJavaFileGenerator} to create the - * {@code BeanDefinitions} file. - */ - private static class BeanDefinitionsJavaFileGenerator implements JavaFileGenerator { - - private final Class target; - - - BeanDefinitionsJavaFileGenerator(Class target) { - this.target = target; - } - - - @Override - public JavaFile generateJavaFile(ClassName className, GeneratedMethods methods) { - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - classBuilder.addJavadoc("Bean definitions for {@link $T}", this.target); - classBuilder.addModifiers(Modifier.PUBLIC); - methods.doWithMethodSpecs(classBuilder::addMethod); - return JavaFile.builder(className.packageName(), classBuilder.build()) - .build(); - } - - @Override - public int hashCode() { - return getClass().hashCode(); - } - - @Override - public boolean equals(Object obj) { - return getClass() == obj.getClass(); - } - - } - } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java index ceefb8e2e6..92e250ba7e 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java @@ -18,7 +18,6 @@ package org.springframework.beans.factory.aot; import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodReference; -import org.springframework.lang.Nullable; /** * Interface that can be used to configure the code that will be generated to @@ -35,24 +34,6 @@ public interface BeanFactoryInitializationCode { */ String BEAN_FACTORY_VARIABLE = "beanFactory"; - /** - * Return the target class for this bean factory or {@code null} if there is - * no target. - * @return the target - */ - @Nullable - default Class getTarget() { - return null; - } - - /** - * Return the name of the bean factory or and empty string if no ID is available. - * @return the bean factory name - */ - default String getName() { - return ""; - } - /** * Return a {@link MethodGenerator} that can be used to add more methods to * the Initializing code. diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java index 52e661b2c1..7b29174c23 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -20,17 +20,15 @@ import java.util.Map; import javax.lang.model.element.Modifier; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; -import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; -import org.springframework.javapoet.TypeSpec; /** * AOT contribution from a {@link BeanRegistrationsAotProcessor} used to @@ -61,24 +59,23 @@ class BeanRegistrationsAotContribution public void applyTo(GenerationContext generationContext, BeanFactoryInitializationCode beanFactoryInitializationCode) { - ClassName className = generationContext.getClassNameGenerator().generateClassName( - beanFactoryInitializationCode.getTarget(), - beanFactoryInitializationCode.getName() + "BeanFactoryRegistrations"); + GeneratedClass generatedClass = generationContext.getGeneratedClasses() + .forFeature("BeanFactoryRegistrations").generate(type -> { + type.addJavadoc("Register bean definitions for the bean factory."); + type.addModifiers(Modifier.PUBLIC); + }); BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator( - className); + generatedClass); GeneratedMethod registerMethod = codeGenerator.getMethodGenerator() .generateMethod("registerBeanDefinitions") .using(builder -> generateRegisterMethod(builder, generationContext, - beanFactoryInitializationCode.getName(), codeGenerator)); - JavaFile javaFile = codeGenerator.generatedJavaFile(className); - generationContext.getGeneratedFiles().addSourceFile(javaFile); beanFactoryInitializationCode - .addInitializer(MethodReference.of(className, registerMethod.getName())); + .addInitializer(MethodReference.of(generatedClass.getName(), registerMethod.getName())); } private void generateRegisterMethod(MethodSpec.Builder builder, - GenerationContext generationContext, String featureNamePrefix, + GenerationContext generationContext, BeanRegistrationsCode beanRegistrationsCode) { builder.addJavadoc("Register the bean definitions."); @@ -88,7 +85,7 @@ class BeanRegistrationsAotContribution CodeBlock.Builder code = CodeBlock.builder(); this.registrations.forEach((beanName, beanDefinitionMethodGenerator) -> { MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator - .generateBeanDefinitionMethod(generationContext, featureNamePrefix, + .generateBeanDefinitionMethod(generationContext, beanRegistrationsCode); code.addStatement("$L.registerBeanDefinition($S, $L)", BEAN_FACTORY_PARAMETER_NAME, beanName, @@ -103,33 +100,21 @@ class BeanRegistrationsAotContribution */ static class BeanRegistrationsCodeGenerator implements BeanRegistrationsCode { - private final ClassName className; + private final GeneratedClass generatedClass; - private final GeneratedMethods generatedMethods = new GeneratedMethods(); - - - public BeanRegistrationsCodeGenerator(ClassName className) { - this.className = className; + public BeanRegistrationsCodeGenerator(GeneratedClass generatedClass) { + this.generatedClass = generatedClass; } @Override public ClassName getClassName() { - return this.className; + return this.generatedClass.getName(); } @Override public MethodGenerator getMethodGenerator() { - return this.generatedMethods; - } - - JavaFile generatedJavaFile(ClassName className) { - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - classBuilder.addJavadoc("Register bean definitions for the bean factory."); - classBuilder.addModifiers(Modifier.PUBLIC); - this.generatedMethods.doWithMethodSpecs(classBuilder::addMethod); - return JavaFile.builder(className.packageName(), classBuilder.build()) - .build(); + return this.generatedClass.getMethodGenerator(); } } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index f5f5137f0f..4f00fa68a4 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -54,18 +54,14 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory; - private final String featureNamePrefix; - DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode, RegisteredBean registeredBean, - BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory, - String featureNamePrefix) { + BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory) { this.beanRegistrationsCode = beanRegistrationsCode; this.registeredBean = registeredBean; this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory; - this.featureNamePrefix = featureNamePrefix; } @@ -124,7 +120,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments .getBeanDefinitionMethodGenerator(innerRegisteredBean, name); Assert.state(methodGenerator != null, "Unexpected filtering of inner-bean"); MethodReference generatedMethod = methodGenerator - .generateBeanDefinitionMethod(generationContext, this.featureNamePrefix, + .generateBeanDefinitionMethod(generationContext, this.beanRegistrationsCode); return generatedMethod.toInvokeCodeBlock(); } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java index 2c88fb2230..4237c126bf 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java @@ -25,7 +25,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.DefaultGenerationContext; -import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.RuntimeHints; @@ -40,6 +39,7 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationCode; import org.springframework.core.env.Environment; import org.springframework.core.env.StandardEnvironment; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; @@ -59,7 +59,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { private InMemoryGeneratedFiles generatedFiles; - private GenerationContext generationContext; + private DefaultGenerationContext generationContext; private RuntimeHints runtimeHints; @@ -70,7 +70,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { @BeforeEach void setup() { this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.generationContext = new TestGenerationContext(this.generatedFiles); this.runtimeHints = this.generationContext.getRuntimeHints(); this.beanRegistrationCode = new MockBeanRegistrationCode(); this.beanFactory = new DefaultListableBeanFactory(); @@ -169,6 +169,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { @SuppressWarnings("unchecked") private void testCompiledResult(RegisteredBean registeredBean, BiConsumer, Compiled> result) { + this.generationContext.writeGeneratedContent(); JavaFile javaFile = createJavaFile(registeredBean.getBeanClass()); TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo, compiled -> result.accept(compiled.getInstance(BiFunction.class), diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 572b70ac2e..a8183ce4cb 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -50,6 +50,7 @@ import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; import org.springframework.core.ResolvableType; import org.springframework.core.mock.MockSpringFactoriesLoader; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; @@ -80,7 +81,7 @@ class BeanDefinitionMethodGeneratorTests { @BeforeEach void setup() { this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.generationContext = new TestGenerationContext(this.generatedFiles); this.beanFactory = new DefaultListableBeanFactory(); this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( new AotFactoriesLoader(this.beanFactory, new MockSpringFactoriesLoader())); @@ -96,7 +97,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); @@ -114,7 +115,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { assertThat(actual.getResolvableType().resolve()).isEqualTo(GenericBean.class); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); @@ -147,7 +148,7 @@ class BeanDefinitionMethodGeneratorTests { BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { assertThat(actual.getBeanClass()).isEqualTo(TestBean.class); InstanceSupplier supplier = (InstanceSupplier) actual @@ -173,7 +174,7 @@ class BeanDefinitionMethodGeneratorTests { BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { assertThat(actual.getBeanClass()).isEqualTo(TestBean.class); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); @@ -213,7 +214,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { assertThat(actual.getAttribute("a")).isEqualTo("A"); assertThat(actual.getAttribute("b")).isNull(); @@ -246,7 +247,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, innerBean, "testInnerBean", Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { assertThat(compiled.getSourceFile(".*BeanDefinitions")) .contains("Get the inner-bean definition for 'testInnerBean'"); @@ -267,7 +268,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual .getPropertyValues().get("name"); @@ -301,7 +302,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual .getConstructorArgumentValues() @@ -334,7 +335,7 @@ class BeanDefinitionMethodGeneratorTests { BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("AotContributedMethod()"); @@ -351,7 +352,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); freshBeanFactory.registerBeanDefinition("test", actual); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java index ae91b492a7..580855d946 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java @@ -29,6 +29,7 @@ import javax.lang.model.element.Modifier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.aot.generate.ClassNameGenerator; import org.springframework.aot.generate.DefaultGenerationContext; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.InMemoryGeneratedFiles; @@ -42,6 +43,8 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryInitializationCode; import org.springframework.core.mock.MockSpringFactoriesLoader; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; +import org.springframework.core.testfixture.aot.generate.TestTarget; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; @@ -72,7 +75,7 @@ class BeanRegistrationsAotContributionTests { @BeforeEach void setup() { this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.generationContext = new TestGenerationContext(this.generatedFiles); this.beanFactory = new DefaultListableBeanFactory(); this.springFactoriesLoader = new MockSpringFactoriesLoader(); this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( @@ -100,7 +103,9 @@ class BeanRegistrationsAotContributionTests { @Test void applyToWhenHasNameGeneratesPrefixedFeatureName() { - this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode("Management"); + this.generationContext = new DefaultGenerationContext( + new ClassNameGenerator(TestTarget.class, "Management"), this.generatedFiles); + this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); Map registrations = new LinkedHashMap<>(); RegisteredBean registeredBean = registerBean( new RootBeanDefinition(TestBean.class)); @@ -129,11 +134,11 @@ class BeanRegistrationsAotContributionTests { @Override MethodReference generateBeanDefinitionMethod( - GenerationContext generationContext, String featureNamePrefix, + GenerationContext generationContext, BeanRegistrationsCode beanRegistrationsCode) { beanRegistrationsCodes.add(beanRegistrationsCode); return super.generateBeanDefinitionMethod(generationContext, - featureNamePrefix, beanRegistrationsCode); + beanRegistrationsCode); } }; diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java index d4de4edc3b..beb867620c 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java @@ -52,6 +52,7 @@ import org.springframework.beans.testfixture.beans.factory.generator.factory.Num import org.springframework.beans.testfixture.beans.factory.generator.factory.SampleFactory; import org.springframework.beans.testfixture.beans.factory.generator.injection.InjectionComponent; import org.springframework.core.env.StandardEnvironment; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; @@ -82,7 +83,7 @@ class InstanceSupplierCodeGeneratorTests { @BeforeEach void setup() { this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.generationContext = new TestGenerationContext(this.generatedFiles); } diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java index d1aee20458..b931cccdeb 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java @@ -35,21 +35,6 @@ public class MockBeanFactoryInitializationCode implements BeanFactoryInitializat private final List initializers = new ArrayList<>(); - private final String name; - - public MockBeanFactoryInitializationCode() { - this(""); - } - - public MockBeanFactoryInitializationCode(String name) { - this.name = name; - } - - @Override - public String getName() { - return this.name; - } - @Override public GeneratedMethods getMethodGenerator() { return this.generatedMethods; diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java index 0ea7827427..f5a450d1e6 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java @@ -16,14 +16,14 @@ package org.springframework.context.aot; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GenerationContext; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.support.GenericApplicationContext; import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.JavaFile; -import org.springframework.lang.Nullable; /** * Process an {@link ApplicationContext} and its {@link BeanFactory} to generate @@ -42,41 +42,20 @@ public class ApplicationContextAotGenerator { * specified {@link GenerationContext}. * @param applicationContext the application context to handle * @param generationContext the generation context to use - * @param generatedInitializerClassName the class name to use for the - * generated application context initializer + * @return the class name of the {@link ApplicationContextInitializer} entry point */ - public void generateApplicationContext(GenericApplicationContext applicationContext, - GenerationContext generationContext, - ClassName generatedInitializerClassName) { - - generateApplicationContext(applicationContext, null, null, generationContext, - generatedInitializerClassName); - } - - /** - * Refresh the specified {@link GenericApplicationContext} and generate the - * necessary code to restore the state of its {@link BeanFactory}, using the - * specified {@link GenerationContext}. - * @param applicationContext the application context to handle - * @param target the target class for the generated initializer (used when generating class names) - * @param name the name of the application context (used when generating class names) - * @param generationContext the generation context to use - * @param generatedInitializerClassName the class name to use for the - * generated application context initializer - */ - public void generateApplicationContext(GenericApplicationContext applicationContext, - @Nullable Class target, @Nullable String name, GenerationContext generationContext, - ClassName generatedInitializerClassName) { - + public ClassName generateApplicationContext(GenericApplicationContext applicationContext, + GenerationContext generationContext) { applicationContext.refreshForAotProcessing(); DefaultListableBeanFactory beanFactory = applicationContext .getDefaultListableBeanFactory(); - ApplicationContextInitializationCodeGenerator codeGenerator = new ApplicationContextInitializationCodeGenerator( - target, name); + ApplicationContextInitializationCodeGenerator codeGenerator = new ApplicationContextInitializationCodeGenerator(); new BeanFactoryInitializationAotContributions(beanFactory).applyTo(generationContext, codeGenerator); - JavaFile javaFile = codeGenerator.generateJavaFile(generatedInitializerClassName); - generationContext.getGeneratedFiles().addSourceFile(javaFile); + GeneratedClass applicationContextInitializer = generationContext.getGeneratedClasses() + .forFeature("ApplicationContextInitializer") + .generate(codeGenerator.generateJavaFile()); + return applicationContextInitializer.getName(); } } diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java index ab4ae35665..fe532edda2 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java @@ -18,6 +18,7 @@ package org.springframework.context.aot; import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; import javax.lang.model.element.Modifier; @@ -29,14 +30,10 @@ import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver; import org.springframework.context.support.GenericApplicationContext; -import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; import org.springframework.javapoet.TypeSpec; -import org.springframework.lang.Nullable; -import org.springframework.util.StringUtils; /** * Internal code generator to create the application context initializer. @@ -50,33 +47,11 @@ class ApplicationContextInitializationCodeGenerator private static final String APPLICATION_CONTEXT_VARIABLE = "applicationContext"; - @Nullable - private final Class target; - - private final String name; - private final GeneratedMethods generatedMethods = new GeneratedMethods(); private final List initializers = new ArrayList<>(); - ApplicationContextInitializationCodeGenerator(@Nullable Class target, @Nullable String name) { - this.target = target; - this.name = (!StringUtils.hasText(name)) ? "" : name; - } - - - @Override - @Nullable - public Class getTarget() { - return this.target; - } - - @Override - public String getName() { - return this.name; - } - @Override public MethodGenerator getMethodGenerator() { return this.generatedMethods; @@ -87,17 +62,17 @@ class ApplicationContextInitializationCodeGenerator this.initializers.add(methodReference); } - JavaFile generateJavaFile(ClassName className) { - TypeSpec.Builder builder = TypeSpec.classBuilder(className); - builder.addJavadoc( - "{@link $T} to restore an application context based on previous AOT processing.", - ApplicationContextInitializer.class); - builder.addModifiers(Modifier.PUBLIC); - builder.addSuperinterface(ParameterizedTypeName.get( - ApplicationContextInitializer.class, GenericApplicationContext.class)); - builder.addMethod(generateInitializeMethod()); - this.generatedMethods.doWithMethodSpecs(builder::addMethod); - return JavaFile.builder(className.packageName(), builder.build()).build(); + Consumer generateJavaFile() { + return builder -> { + builder.addJavadoc( + "{@link $T} to restore an application context based on previous AOT processing.", + ApplicationContextInitializer.class); + builder.addModifiers(Modifier.PUBLIC); + builder.addSuperinterface(ParameterizedTypeName.get( + ApplicationContextInitializer.class, GenericApplicationContext.class)); + builder.addMethod(generateInitializeMethod()); + this.generatedMethods.doWithMethodSpecs(builder::addMethod); + }; } private MethodSpec generateInitializeMethod() { diff --git a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java index 46d7080c16..393315fd48 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java @@ -37,6 +37,7 @@ import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryIn import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration; import org.springframework.context.testfixture.context.generator.annotation.ImportAwareConfiguration; import org.springframework.context.testfixture.context.generator.annotation.ImportConfiguration; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; @@ -59,7 +60,7 @@ class ConfigurationClassPostProcessorAotContributionTests { private InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); - private DefaultGenerationContext generationContext = new DefaultGenerationContext( + private DefaultGenerationContext generationContext = new TestGenerationContext( this.generatedFiles); private MockBeanFactoryInitializationCode beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java index f3a5baf964..30e30f3416 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java @@ -44,7 +44,7 @@ import org.springframework.context.support.GenericApplicationContext; import org.springframework.context.testfixture.context.generator.SimpleComponent; import org.springframework.context.testfixture.context.generator.annotation.AutowiredComponent; import org.springframework.context.testfixture.context.generator.annotation.InitDestroyComponent; -import org.springframework.javapoet.ClassName; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import static org.assertj.core.api.Assertions.assertThat; @@ -56,9 +56,6 @@ import static org.assertj.core.api.Assertions.assertThat; */ class ApplicationContextAotGeneratorTests { - private static final ClassName MAIN_GENERATED_TYPE = ClassName.get("__", - "TestInitializer"); - @Test void generateApplicationContextWhenHasSimpleBean() { GenericApplicationContext applicationContext = new GenericApplicationContext(); @@ -191,10 +188,9 @@ class ApplicationContextAotGeneratorTests { BiConsumer, Compiled> result) { ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); - DefaultGenerationContext generationContext = new DefaultGenerationContext( + DefaultGenerationContext generationContext = new TestGenerationContext( generatedFiles); - generator.generateApplicationContext(applicationContext, generationContext, - MAIN_GENERATED_TYPE); + generator.generateApplicationContext(applicationContext, generationContext); generationContext.writeGeneratedContent(); TestCompiler.forSystem().withFiles(generatedFiles) .compile(compiled -> result.accept( diff --git a/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java index 6cb85f517f..06a5829850 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java @@ -24,9 +24,7 @@ import java.lang.annotation.Target; import org.junit.jupiter.api.Test; -import org.springframework.aot.generate.DefaultGenerationContext; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsPredicates; @@ -39,6 +37,7 @@ import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.annotation.AliasFor; import org.springframework.core.annotation.SynthesizedAnnotation; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.lang.Nullable; import static org.assertj.core.api.Assertions.assertThat; @@ -54,8 +53,7 @@ class ReflectiveProcessorBeanRegistrationAotProcessorTests { private final ReflectiveProcessorBeanRegistrationAotProcessor processor = new ReflectiveProcessorBeanRegistrationAotProcessor(); - private final GenerationContext generationContext = new DefaultGenerationContext( - new InMemoryGeneratedFiles()); + private final GenerationContext generationContext = new TestGenerationContext(); @Test void shouldIgnoreNonAnnotatedType() { diff --git a/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java index 8ff69e4bd5..16dd7185fb 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java @@ -25,9 +25,7 @@ import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.aot.generate.DefaultGenerationContext; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.hint.ResourceBundleHint; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; @@ -38,7 +36,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.ImportRuntimeHints; import org.springframework.context.support.GenericApplicationContext; -import org.springframework.javapoet.ClassName; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.lang.Nullable; import static org.assertj.core.api.Assertions.assertThat; @@ -51,17 +49,13 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; */ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { - private static final ClassName MAIN_GENERATED_TYPE = ClassName.get("__", - "TestInitializer"); - private GenerationContext generationContext; private ApplicationContextAotGenerator generator; @BeforeEach void setup() { - this.generationContext = new DefaultGenerationContext( - new InMemoryGeneratedFiles()); + this.generationContext = new TestGenerationContext(); this.generator = new ApplicationContextAotGenerator(); } @@ -70,7 +64,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { GenericApplicationContext applicationContext = createApplicationContext( ConfigurationWithHints.class); this.generator.generateApplicationContext(applicationContext, - this.generationContext, MAIN_GENERATED_TYPE); + this.generationContext); assertThatSampleRegistrarContributed(); } @@ -79,7 +73,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { GenericApplicationContext applicationContext = createApplicationContext( ConfigurationWithBeanDeclaringHints.class); this.generator.generateApplicationContext(applicationContext, - this.generationContext, MAIN_GENERATED_TYPE); + this.generationContext); assertThatSampleRegistrarContributed(); } @@ -89,7 +83,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { applicationContext.setClassLoader( new TestSpringFactoriesClassLoader("test-runtime-hints-aot.factories")); this.generator.generateApplicationContext(applicationContext, - this.generationContext, MAIN_GENERATED_TYPE); + this.generationContext); assertThatSampleRegistrarContributed(); } @@ -104,7 +98,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { new TestSpringFactoriesClassLoader("test-duplicated-runtime-hints-aot.factories")); IncrementalRuntimeHintsRegistrar.counter.set(0); this.generator.generateApplicationContext(applicationContext, - this.generationContext, MAIN_GENERATED_TYPE); + this.generationContext); RuntimeHints runtimeHints = this.generationContext.getRuntimeHints(); assertThat(runtimeHints.resources().resourceBundles().map(ResourceBundleHint::getBaseName)) .containsOnly("com.example.example0", "sample"); @@ -116,7 +110,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { GenericApplicationContext applicationContext = createApplicationContext( ConfigurationWithIllegalRegistrar.class); assertThatThrownBy(() -> this.generator.generateApplicationContext( - applicationContext, this.generationContext, MAIN_GENERATED_TYPE)) + applicationContext, this.generationContext)) .isInstanceOf(BeanInstantiationException.class); } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/ClassGenerator.java b/spring-core/src/main/java/org/springframework/aot/generate/ClassGenerator.java deleted file mode 100644 index 2accace2a0..0000000000 --- a/spring-core/src/main/java/org/springframework/aot/generate/ClassGenerator.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import java.util.Collection; -import java.util.Collections; - -import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.JavaFile; - -/** - * Generates new {@link GeneratedClass} instances. - * - * @author Phillip Webb - * @since 6.0 - * @see GeneratedMethods - */ -public interface ClassGenerator { - - /** - * Get or generate a new {@link GeneratedClass} for a given java file - * generator, target and feature name. - * @param javaFileGenerator the java file generator - * @param target the target of the newly generated class - * @param featureName the name of the feature that the generated class - * supports - * @return a {@link GeneratedClass} instance - */ - GeneratedClass getOrGenerateClass(JavaFileGenerator javaFileGenerator, - Class target, String featureName); - - - /** - * Strategy used to generate the java file for the generated class. - * Implementations of this interface are included as part of the key used to - * identify classes that have already been created and as such should be - * static final instances or implement a valid - * {@code equals}/{@code hashCode}. - */ - @FunctionalInterface - interface JavaFileGenerator { - - /** - * Generate the file {@link JavaFile} to be written. - * @param className the class name of the file - * @param methods the generated methods that must be included - * @return the generated files - */ - JavaFile generateJavaFile(ClassName className, GeneratedMethods methods); - - /** - * Return method names that must not be generated. - * @return the reserved method names - */ - default Collection getReservedMethodNames() { - return Collections.emptySet(); - } - - } - -} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java b/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java index ec6d7cfd99..02d4d6bccb 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java @@ -27,10 +27,9 @@ import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; /** - * Generate unique class names based on an optional target {@link Class} and - * a feature name. This class is stateful so the same instance should be used - * for all name generation. Most commonly the class name generator is obtained - * via a {@link GenerationContext}. + * Generate unique class names based on target {@link Class} and a feature + * name. This class is stateful so the same instance should be used for all + * name generation. * * @author Phillip Webb * @author Stephane Nicoll @@ -40,38 +39,92 @@ public final class ClassNameGenerator { private static final String SEPARATOR = "__"; - private static final String AOT_PACKAGE = "__."; - private static final String AOT_FEATURE = "Aot"; - private final Map sequenceGenerator = new ConcurrentHashMap<>(); + private final Class defaultTarget; + + private final String featureNamePrefix; + + private final Map sequenceGenerator; + + /** + * Create a new instance using the specified {@code defaultTarget} and no + * feature name prefix. + * @param defaultTarget the default target class to use + */ + public ClassNameGenerator(Class defaultTarget) { + this(defaultTarget, ""); + } + + /** + * Create a new instance using the specified {@code defaultTarget} and + * feature name prefix. + * @param defaultTarget the default target class to use + * @param featureNamePrefix the prefix to use to qualify feature names + */ + public ClassNameGenerator(Class defaultTarget, String featureNamePrefix) { + this(defaultTarget, featureNamePrefix, new ConcurrentHashMap<>()); + } + + private ClassNameGenerator(Class defaultTarget, String featureNamePrefix, + Map sequenceGenerator) { + this.defaultTarget = defaultTarget; + this.featureNamePrefix = (!StringUtils.hasText(featureNamePrefix) ? "" : featureNamePrefix); + this.sequenceGenerator = sequenceGenerator; + } /** - * Generate a unique {@link ClassName} based on the specified {@code target} - * class and {@code featureName}. If a {@code target} is specified, the - * generated class name is a suffixed version of it. - *

For instance, a {@code com.example.Demo} target with an - * {@code Initializer} feature name leads to a - * {@code com.example.Demo__Initializer} generated class name. If such a - * feature was already requested for this target, a counter is used to - * ensure uniqueness. - *

If there is no target, the {@code featureName} is used to generate the - * class name in the {@value #AOT_PACKAGE} package. + * Generate a unique {@link ClassName} based on the specified + * {@code featureName} and {@code target}. If the {@code target} is + * {@code null}, the configured main target of this instance is used. + *

The class name is a suffixed version of the target. For instance, a + * {@code com.example.Demo} target with an {@code Initializer} feature name + * leads to a {@code com.example.Demo__Initializer} generated class name. + * The feature name is qualified by the configured feature name prefix, + * if any. + *

Generated class names are unique. If such a feature was already + * requested for this target, a counter is used to ensure uniqueness. * @param target the class the newly generated class relates to, or - * {@code null} if there is not target + * {@code null} to use the main target * @param featureName the name of the feature that the generated class * supports * @return a unique generated class name */ public ClassName generateClassName(@Nullable Class target, String featureName) { + return generateSequencedClassName(getClassName(target, featureName)); + } + + /** + * Return a class name based on the specified {@code target} and + * {@code featureName}. This uses the same algorithm as + * {@link #generateClassName(Class, String)} but does not register + * the class name, nor add a unique suffix to it if necessary. + * @param target the class the newly generated class relates to, or + * {@code null} to use the main target + * @param featureName the name of the feature that the generated class + * supports + * @return the class name + */ + String getClassName(@Nullable Class target, String featureName) { Assert.hasLength(featureName, "'featureName' must not be empty"); featureName = clean(featureName); - if (target != null) { - return generateSequencedClassName(target.getName().replace("$", "_") - + SEPARATOR + StringUtils.capitalize(featureName)); - } - return generateSequencedClassName(AOT_PACKAGE + featureName); + Class targetToUse = (target != null ? target : this.defaultTarget); + String featureNameToUse = this.featureNamePrefix + featureName; + return targetToUse.getName().replace("$", "_") + + SEPARATOR + StringUtils.capitalize(featureNameToUse); + } + + /** + * Return a new {@link ClassNameGenerator} instance for the specified + * feature name prefix, keeping track of all the class names generated + * by this instance. + * @param featureNamePrefix the feature name prefix to use + * @return a new instance for the specified feature name prefix + */ + ClassNameGenerator usingFeatureNamePrefix(String featureNamePrefix) { + return new ClassNameGenerator(this.defaultTarget, featureNamePrefix, + this.sequenceGenerator); } private String clean(String name) { diff --git a/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java b/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java index 4698d28dbb..dd500d0b84 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java @@ -17,12 +17,15 @@ package org.springframework.aot.generate; import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.springframework.aot.hint.RuntimeHints; import org.springframework.util.Assert; /** - * Default implementation of {@link GenerationContext}. + * Default {@link GenerationContext} implementation. * * @author Phillip Webb * @author Stephane Nicoll @@ -30,7 +33,7 @@ import org.springframework.util.Assert; */ public class DefaultGenerationContext implements GenerationContext { - private final ClassNameGenerator classNameGenerator; + private final Map sequenceGenerator; private final GeneratedClasses generatedClasses; @@ -41,39 +44,45 @@ public class DefaultGenerationContext implements GenerationContext { /** * Create a new {@link DefaultGenerationContext} instance backed by the - * specified {@code generatedFiles}. + * specified {@link ClassNameGenerator} and {@link GeneratedFiles}. + * @param classNameGenerator the naming convention to use for generated + * class names * @param generatedFiles the generated files */ - public DefaultGenerationContext(GeneratedFiles generatedFiles) { - this(new ClassNameGenerator(), generatedFiles, new RuntimeHints()); + public DefaultGenerationContext(ClassNameGenerator classNameGenerator, GeneratedFiles generatedFiles) { + this(new GeneratedClasses(classNameGenerator), generatedFiles, new RuntimeHints()); } /** * Create a new {@link DefaultGenerationContext} instance backed by the * specified items. - * @param classNameGenerator the class name generator + * @param generatedClasses the generated classes * @param generatedFiles the generated files * @param runtimeHints the runtime hints */ - public DefaultGenerationContext(ClassNameGenerator classNameGenerator, + public DefaultGenerationContext(GeneratedClasses generatedClasses, GeneratedFiles generatedFiles, RuntimeHints runtimeHints) { - Assert.notNull(classNameGenerator, "'classNameGenerator' must not be null"); + Assert.notNull(generatedClasses, "'generatedClasses' must not be null"); Assert.notNull(generatedFiles, "'generatedFiles' must not be null"); Assert.notNull(runtimeHints, "'runtimeHints' must not be null"); - this.classNameGenerator = classNameGenerator; - this.generatedClasses = new GeneratedClasses(classNameGenerator); + this.sequenceGenerator = new ConcurrentHashMap<>(); + this.generatedClasses = generatedClasses; this.generatedFiles = generatedFiles; this.runtimeHints = runtimeHints; } - - @Override - public ClassNameGenerator getClassNameGenerator() { - return this.classNameGenerator; + private DefaultGenerationContext(DefaultGenerationContext existing, String name) { + int sequence = existing.sequenceGenerator + .computeIfAbsent(name, key -> new AtomicInteger()).getAndIncrement(); + String nameToUse = (sequence > 0 ? name + sequence : name); + this.sequenceGenerator = existing.sequenceGenerator; + this.generatedClasses = existing.generatedClasses.withName(nameToUse); + this.generatedFiles = existing.generatedFiles; + this.runtimeHints = existing.runtimeHints; } @Override - public GeneratedClasses getClassGenerator() { + public GeneratedClasses getGeneratedClasses() { return this.generatedClasses; } @@ -87,6 +96,11 @@ public class DefaultGenerationContext implements GenerationContext { return this.runtimeHints; } + @Override + public GenerationContext withName(String name) { + return new DefaultGenerationContext(this, name); + } + /** * Write any generated content out to the generated files. */ diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java index caef4730e0..12c34c76c9 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java @@ -16,22 +16,24 @@ package org.springframework.aot.generate; -import org.springframework.aot.generate.ClassGenerator.JavaFileGenerator; +import java.util.function.Consumer; + import org.springframework.javapoet.ClassName; import org.springframework.javapoet.JavaFile; -import org.springframework.util.Assert; +import org.springframework.javapoet.TypeSpec; +import org.springframework.javapoet.TypeSpec.Builder; /** - * A generated class. + * A generated class is a container for generated methods. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see GeneratedClasses - * @see ClassGenerator */ public final class GeneratedClass { - private final JavaFileGenerator JavaFileGenerator; + private final Consumer typeSpecCustomizer; private final ClassName name; @@ -44,12 +46,10 @@ public final class GeneratedClass { * {@link GeneratedClasses}. * @param name the generated name */ - GeneratedClass(JavaFileGenerator javaFileGenerator, ClassName name) { - MethodNameGenerator methodNameGenerator = new MethodNameGenerator( - javaFileGenerator.getReservedMethodNames()); - this.JavaFileGenerator = javaFileGenerator; + GeneratedClass(Consumer typeSpecCustomizer, ClassName name) { + this.typeSpecCustomizer = typeSpecCustomizer; this.name = name; - this.methods = new GeneratedMethods(methodNameGenerator); + this.methods = new GeneratedMethods(new MethodNameGenerator()); } @@ -70,15 +70,11 @@ public final class GeneratedClass { } JavaFile generateJavaFile() { - JavaFile javaFile = this.JavaFileGenerator.generateJavaFile(this.name, - this.methods); - Assert.state(this.name.packageName().equals(javaFile.packageName), - () -> "Generated JavaFile should be in package '" - + this.name.packageName() + "'"); - Assert.state(this.name.simpleName().equals(javaFile.typeSpec.name), - () -> "Generated JavaFile should be named '" + this.name.simpleName() - + "'"); - return javaFile; + TypeSpec.Builder typeSpecBuilder = TypeSpec.classBuilder(this.name); + this.typeSpecCustomizer.accept(typeSpecBuilder); + this.methods.doWithMethodSpecs(typeSpecBuilder::addMethod); + return JavaFile.builder(this.name.packageName(), typeSpecBuilder.build()) + .build(); } } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java index 09e654b3a1..7d86b7faf0 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java @@ -22,59 +22,143 @@ import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.TypeSpec; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** - * A managed collection of generated classes. + * A managed collection of generated classes. This class is stateful so the + * same instance should be used for all class generation. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see GeneratedClass */ -public class GeneratedClasses implements ClassGenerator { +public class GeneratedClasses { private final ClassNameGenerator classNameGenerator; - private final Map classes = new ConcurrentHashMap<>(); + private final List classes; + private final Map classesByOwner; + /** + * Create a new instance using the specified naming conventions. + * @param classNameGenerator the class name generator to use + */ public GeneratedClasses(ClassNameGenerator classNameGenerator) { - Assert.notNull(classNameGenerator, "'classNameGenerator' must not be null"); - this.classNameGenerator = classNameGenerator; + this(classNameGenerator, new ArrayList<>(), new ConcurrentHashMap<>()); } - - @Override - public GeneratedClass getOrGenerateClass(JavaFileGenerator javaFileGenerator, - Class target, String featureName) { - - Assert.notNull(javaFileGenerator, "'javaFileGenerator' must not be null"); - Assert.notNull(target, "'target' must not be null"); - Assert.hasLength(featureName, "'featureName' must not be empty"); - Owner owner = new Owner(javaFileGenerator, target.getName(), featureName); - return this.classes.computeIfAbsent(owner, - key -> new GeneratedClass(javaFileGenerator, - this.classNameGenerator.generateClassName(target, featureName))); + private GeneratedClasses(ClassNameGenerator classNameGenerator, + List classes, Map classesByOwner) { + Assert.notNull(classNameGenerator, "'classNameGenerator' must not be null"); + this.classNameGenerator = classNameGenerator; + this.classes = classes; + this.classesByOwner = classesByOwner; } /** - * Write generated Spring {@code .factories} files to the given + * Prepare a {@link GeneratedClass} for the specified {@code featureName} + * targeting the specified {@code component}. + * @param featureName the name of the feature to associate with the generated class + * @param component the target component + * @return a {@link Builder} for further configuration + */ + public Builder forFeatureComponent(String featureName, Class component) { + Assert.hasLength(featureName, "'featureName' must not be empty"); + Assert.notNull(component, "'component' must not be null"); + return new Builder(featureName, component); + } + + /** + * Prepare a {@link GeneratedClass} for the specified {@code featureName} + * and no particular component. This should be used for high-level code + * generation that are widely applicable and for entry points. + * @param featureName the name of the feature to associate with the generated class + * @return a {@link Builder} for further configuration + */ + public Builder forFeature(String featureName) { + Assert.hasLength(featureName, "'featureName' must not be empty"); + return new Builder(featureName, null); + } + + /** + * Write the {@link GeneratedClass generated classes} using the given * {@link GeneratedFiles} instance. - * @param generatedFiles where to write the generated files + * @param generatedFiles where to write the generated classes * @throws IOException on IO error */ public void writeTo(GeneratedFiles generatedFiles) throws IOException { Assert.notNull(generatedFiles, "'generatedFiles' must not be null"); - List generatedClasses = new ArrayList<>(this.classes.values()); + List generatedClasses = new ArrayList<>(this.classes); generatedClasses.sort(Comparator.comparing(GeneratedClass::getName)); for (GeneratedClass generatedClass : generatedClasses) { generatedFiles.addSourceFile(generatedClass.generateJavaFile()); } } - private record Owner(JavaFileGenerator javaFileGenerator, String target, - String featureName) { + GeneratedClasses withName(String name) { + return new GeneratedClasses(this.classNameGenerator.usingFeatureNamePrefix(name), + this.classes, this.classesByOwner); + } + + private record Owner(String id, String className) { + + } + + public class Builder { + + private final String featureName; + + @Nullable + private final Class target; + + + Builder(String featureName, @Nullable Class target) { + this.target = target; + this.featureName = featureName; + } + + /** + * Generate a new {@link GeneratedClass} using the specified type + * customizer. + * @param typeSpecCustomizer a customizer for the {@link TypeSpec.Builder} + * @return a new {@link GeneratedClass} + */ + public GeneratedClass generate(Consumer typeSpecCustomizer) { + Assert.notNull(typeSpecCustomizer, "'typeSpecCustomizer' must not be null"); + return createGeneratedClass(typeSpecCustomizer); + } + + + /** + * Get or generate a new {@link GeneratedClass} for the specified {@code id}. + * @param id a unique identifier + * @param typeSpecCustomizer a customizer for the {@link TypeSpec.Builder} + * @return a {@link GeneratedClass} instance + */ + public GeneratedClass getOrGenerate(String id, + Consumer typeSpecCustomizer) { + Assert.hasLength(id, "'id' must not be empty"); + Assert.notNull(typeSpecCustomizer, "'typeSpecCustomizer' must not be null"); + Owner owner = new Owner(id, GeneratedClasses.this.classNameGenerator + .getClassName(this.target, this.featureName)); + return GeneratedClasses.this.classesByOwner.computeIfAbsent(owner, + key -> createGeneratedClass(typeSpecCustomizer)); + } + + private GeneratedClass createGeneratedClass(Consumer typeSpecCustomizer) { + ClassName className = GeneratedClasses.this.classNameGenerator + .generateClassName(this.target, this.featureName); + GeneratedClass generatedClass = new GeneratedClass(typeSpecCustomizer, className); + GeneratedClasses.this.classes.add(generatedClass); + return generatedClass; + } } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java b/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java index 2cd1c770bf..d74d97e98f 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java @@ -24,38 +24,31 @@ import org.springframework.aot.hint.SerializationHints; /** * Central interface used for code generation. - *

- * A generation context provides: + * + *

A generation context provides: *

    - *
  • Support for {@link #getClassNameGenerator() class name generation}.
  • - *
  • Central management of all {@link #getGeneratedFiles() generated - * files}.
  • - *
  • Support for the recording of {@link #getRuntimeHints() runtime - * hints}.
  • + *
  • Management of all {@link #getGeneratedClasses()} generated classes}, + * including naming convention support.
  • + *
  • Central management of all {@link #getGeneratedFiles() generated files}.
  • + *
  • Support for the recording of {@link #getRuntimeHints() runtime hints}.
  • *
* + *

If a dedicated round of code generation is required while processing, it + * is possible to create a specialized context using {@link #withName(String)}. + * * @author Phillip Webb * @author Stephane Nicoll * @since 6.0 */ public interface GenerationContext { - /** - * Return the {@link ClassNameGenerator} being used by the context. Allows - * new class names to be generated before they are added to the - * {@link #getGeneratedFiles() generated files}. - * @return the class name generator - * @see #getGeneratedFiles() - */ - ClassNameGenerator getClassNameGenerator(); - /** * Return the {@link GeneratedClasses} being used by the context. Allows a * single generated class to be shared across multiple AOT processors. All * generated classes are written at the end of AOT processing. * @return the generated classes */ - ClassGenerator getClassGenerator(); + GeneratedClasses getGeneratedClasses(); /** * Return the {@link GeneratedFiles} being used by the context. Used to @@ -73,4 +66,14 @@ public interface GenerationContext { */ RuntimeHints getRuntimeHints(); + /** + * Return a new {@link GenerationContext} instance using the specified + * name to qualify generated assets for a dedicated round of code + * generation. If this name is already in use, a unique sequence is added + * to ensure the name is unique. + * @param name the name to use + * @return a specialized {@link GenerationContext} for the specified name + */ + GenerationContext withName(String name); + } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java b/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java index f45c3e1618..ae8354341d 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java @@ -32,12 +32,26 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException */ class ClassNameGeneratorTests { - private final ClassNameGenerator generator = new ClassNameGenerator(); + private final ClassNameGenerator generator = new ClassNameGenerator(Object.class); @Test - void generateClassNameWhenTargetClassIsNullUsesAotPackage() { - ClassName generated = this.generator.generateClassName((Class) null, "test"); - assertThat(generated).hasToString("__.Test"); + void generateClassNameWhenTargetClassIsNullUsesMainTarget() { + ClassName generated = this.generator.generateClassName(null, "test"); + assertThat(generated).hasToString("java.lang.Object__Test"); + } + + @Test + void generateClassNameUseFeatureNamePrefix() { + ClassName generated = new ClassNameGenerator(Object.class, "One") + .generateClassName(InputStream.class, "test"); + assertThat(generated).hasToString("java.io.InputStream__OneTest"); + } + + @Test + void generateClassNameWithNoTextFeatureNamePrefix() { + ClassName generated = new ClassNameGenerator(Object.class, " ") + .generateClassName(InputStream.class, "test"); + assertThat(generated).hasToString("java.io.InputStream__Test"); } @Test @@ -59,8 +73,7 @@ class ClassNameGeneratorTests { @Test void generateClassNameWithClassWhenLowercaseFeatureNameGeneratesName() { - ClassName generated = this.generator.generateClassName(InputStream.class, - "bytes"); + ClassName generated = this.generator.generateClassName(InputStream.class, "bytes"); assertThat(generated).hasToString("java.io.InputStream__Bytes"); } @@ -68,7 +81,7 @@ class ClassNameGeneratorTests { void generateClassNameWithClassWhenInnerClassGeneratesName() { ClassName generated = this.generator.generateClassName(TestBean.class, "EventListener"); assertThat(generated) - .hasToString("org.springframework.aot.generate.ClassNameGeneratorTests_TestBean__EventListener"); + .hasToString("org.springframework.aot.generate.ClassNameGeneratorTests_TestBean__EventListener"); } @Test @@ -81,6 +94,15 @@ class ClassNameGeneratorTests { assertThat(generated3).hasToString("java.io.InputStream__Bytes2"); } + @Test + void getClassNameWhenMultipleCallsReturnsSameName() { + String name1 = this.generator.getClassName(InputStream.class, "bytes"); + String name2 = this.generator.getClassName(InputStream.class, "bytes"); + String name3 = this.generator.getClassName(InputStream.class, "bytes"); + assertThat(name1).hasToString("java.io.InputStream__Bytes") + .isEqualTo(name2).isEqualTo(name3); + } + static class TestBean { } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java b/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java index a6d84e87eb..306a5ae047 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java @@ -16,9 +16,14 @@ package org.springframework.aot.generate; +import java.util.function.Consumer; + import org.junit.jupiter.api.Test; +import org.springframework.aot.generate.GeneratedFiles.Kind; import org.springframework.aot.hint.RuntimeHints; +import org.springframework.core.testfixture.aot.generate.TestTarget; +import org.springframework.javapoet.TypeSpec.Builder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -31,9 +36,12 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException */ class DefaultGenerationContextTests { - private final ClassNameGenerator classNameGenerator = new ClassNameGenerator(); + private static final Consumer typeSpecCustomizer = type -> {}; - private final GeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); + private final GeneratedClasses generatedClasses = new GeneratedClasses( + new ClassNameGenerator(TestTarget.class)); + + private final InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); private final RuntimeHints runtimeHints = new RuntimeHints(); @@ -41,9 +49,7 @@ class DefaultGenerationContextTests { @Test void createWithOnlyGeneratedFilesCreatesContext() { DefaultGenerationContext context = new DefaultGenerationContext( - this.generatedFiles); - assertThat(context.getClassNameGenerator()) - .isInstanceOf(ClassNameGenerator.class); + new ClassNameGenerator(TestTarget.class), this.generatedFiles); assertThat(context.getGeneratedFiles()).isSameAs(this.generatedFiles); assertThat(context.getRuntimeHints()).isInstanceOf(RuntimeHints.class); } @@ -51,24 +57,23 @@ class DefaultGenerationContextTests { @Test void createCreatesContext() { DefaultGenerationContext context = new DefaultGenerationContext( - this.classNameGenerator, this.generatedFiles, this.runtimeHints); - assertThat(context.getClassNameGenerator()).isNotNull(); + this.generatedClasses, this.generatedFiles, this.runtimeHints); assertThat(context.getGeneratedFiles()).isNotNull(); assertThat(context.getRuntimeHints()).isNotNull(); } @Test - void createWhenClassNameGeneratorIsNullThrowsException() { + void createWhenGeneratedClassesIsNullThrowsException() { assertThatIllegalArgumentException() .isThrownBy(() -> new DefaultGenerationContext(null, this.generatedFiles, this.runtimeHints)) - .withMessage("'classNameGenerator' must not be null"); + .withMessage("'generatedClasses' must not be null"); } @Test void createWhenGeneratedFilesIsNullThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> new DefaultGenerationContext(this.classNameGenerator, + .isThrownBy(() -> new DefaultGenerationContext(this.generatedClasses, null, this.runtimeHints)) .withMessage("'generatedFiles' must not be null"); } @@ -76,30 +81,71 @@ class DefaultGenerationContextTests { @Test void createWhenRuntimeHintsIsNullThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> new DefaultGenerationContext(this.classNameGenerator, + .isThrownBy(() -> new DefaultGenerationContext(this.generatedClasses, this.generatedFiles, null)) .withMessage("'runtimeHints' must not be null"); } @Test - void getClassNameGeneratorReturnsClassNameGenerator() { + void getGeneratedClassesReturnsClassNameGenerator() { DefaultGenerationContext context = new DefaultGenerationContext( - this.classNameGenerator, this.generatedFiles, this.runtimeHints); - assertThat(context.getClassNameGenerator()).isSameAs(this.classNameGenerator); + this.generatedClasses, this.generatedFiles, this.runtimeHints); + assertThat(context.getGeneratedClasses()).isSameAs(this.generatedClasses); } @Test void getGeneratedFilesReturnsGeneratedFiles() { DefaultGenerationContext context = new DefaultGenerationContext( - this.classNameGenerator, this.generatedFiles, this.runtimeHints); + this.generatedClasses, this.generatedFiles, this.runtimeHints); assertThat(context.getGeneratedFiles()).isSameAs(this.generatedFiles); } @Test void getRuntimeHintsReturnsRuntimeHints() { DefaultGenerationContext context = new DefaultGenerationContext( - this.classNameGenerator, this.generatedFiles, this.runtimeHints); + this.generatedClasses, this.generatedFiles, this.runtimeHints); assertThat(context.getRuntimeHints()).isSameAs(this.runtimeHints); } + @Test + void withNameUpdateNamingConvention() { + DefaultGenerationContext context = new DefaultGenerationContext( + new ClassNameGenerator(TestTarget.class), this.generatedFiles); + GenerationContext anotherContext = context.withName("Another"); + GeneratedClass generatedClass = anotherContext.getGeneratedClasses() + .forFeature("Test").generate(typeSpecCustomizer); + assertThat(generatedClass.getName().simpleName()).endsWith("__AnotherTest"); + } + + @Test + void withNameKeepTrackOfAllGeneratedFiles() { + DefaultGenerationContext context = new DefaultGenerationContext( + new ClassNameGenerator(TestTarget.class), this.generatedFiles); + context.getGeneratedClasses().forFeature("Test").generate(typeSpecCustomizer); + GenerationContext anotherContext = context.withName("Another"); + assertThat(anotherContext.getGeneratedClasses()).isNotSameAs(context.getGeneratedClasses()); + assertThat(anotherContext.getGeneratedFiles()).isSameAs(context.getGeneratedFiles()); + assertThat(anotherContext.getRuntimeHints()).isSameAs(context.getRuntimeHints()); + anotherContext.getGeneratedClasses().forFeature("Test").generate(typeSpecCustomizer); + context.writeGeneratedContent(); + assertThat(this.generatedFiles.getGeneratedFiles(Kind.SOURCE)).hasSize(2); + } + + @Test + void withNameGenerateUniqueName() { + DefaultGenerationContext context = new DefaultGenerationContext( + new ClassNameGenerator(Object.class), this.generatedFiles); + context.withName("Test").getGeneratedClasses() + .forFeature("Feature").generate(typeSpecCustomizer); + context.withName("Test").getGeneratedClasses() + .forFeature("Feature").generate(typeSpecCustomizer); + context.withName("Test").getGeneratedClasses() + .forFeature("Feature").generate(typeSpecCustomizer); + context.writeGeneratedContent(); + assertThat(this.generatedFiles.getGeneratedFiles(Kind.SOURCE)).containsOnlyKeys( + "java/lang/Object__TestFeature.java", + "java/lang/Object__Test1Feature.java", + "java/lang/Object__Test2Feature.java"); + } + } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java index 35ef736c8c..df0245c652 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java @@ -16,77 +16,43 @@ package org.springframework.aot.generate; +import java.util.function.Consumer; + import org.junit.jupiter.api.Test; import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.JavaFile; -import org.springframework.javapoet.TypeSpec; +import org.springframework.javapoet.TypeSpec.Builder; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** * Tests for {@link GeneratedClass}. * * @author Phillip Webb + * @author Stephane Nicoll */ class GeneratedClassTests { @Test void getNameReturnsName() { ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(this::generateJavaFile, name); + GeneratedClass generatedClass = new GeneratedClass(emptyTypeSpec(), name); assertThat(generatedClass.getName()).isSameAs(name); } @Test - void generateJavaFileSuppliesGeneratedMethods() { + void generateJavaFileIncludesGeneratedMethods() { ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(this::generateJavaFile, name); + GeneratedClass generatedClass = new GeneratedClass(emptyTypeSpec(), name); MethodGenerator methodGenerator = generatedClass.getMethodGenerator(); methodGenerator.generateMethod("test") .using(builder -> builder.addJavadoc("Test Method")); assertThat(generatedClass.generateJavaFile().toString()).contains("Test Method"); } - @Test - void generateJavaFileWhenHasBadPackageThrowsException() { - ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass( - this::generateBadPackageJavaFile, name); - assertThatIllegalStateException() - .isThrownBy( - () -> assertThat(generatedClass.generateJavaFile().toString())) - .withMessageContaining("should be in package"); - } - @Test - void generateJavaFileWhenHasBadNameThrowsException() { - ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(this::generateBadNameJavaFile, - name); - assertThatIllegalStateException() - .isThrownBy( - () -> assertThat(generatedClass.generateJavaFile().toString())) - .withMessageContaining("should be named"); - } - - private JavaFile generateJavaFile(ClassName className, GeneratedMethods methods) { - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - methods.doWithMethodSpecs(classBuilder::addMethod); - return JavaFile.builder(className.packageName(), classBuilder.build()).build(); - } - - private JavaFile generateBadPackageJavaFile(ClassName className, - GeneratedMethods methods) { - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - return JavaFile.builder("naughty", classBuilder.build()).build(); - } - - private JavaFile generateBadNameJavaFile(ClassName className, - GeneratedMethods methods) { - TypeSpec.Builder classBuilder = TypeSpec.classBuilder("Naughty"); - return JavaFile.builder(className.packageName(), classBuilder.build()).build(); + private Consumer emptyTypeSpec() { + return type -> {}; } } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java index 420e610f19..7aca293319 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java @@ -16,27 +16,34 @@ package org.springframework.aot.generate; +import java.io.IOException; +import java.util.function.Consumer; + import org.junit.jupiter.api.Test; -import org.springframework.aot.generate.ClassGenerator.JavaFileGenerator; -import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.JavaFile; +import org.springframework.aot.generate.GeneratedFiles.Kind; import org.springframework.javapoet.TypeSpec; +import org.springframework.javapoet.TypeSpec.Builder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; /** * Tests for {@link GeneratedClasses}. * * @author Phillip Webb + * @author Stephane Nicoll */ class GeneratedClassesTests { - private GeneratedClasses generatedClasses = new GeneratedClasses( - new ClassNameGenerator()); + private static final Consumer emptyTypeCustomizer = type -> {}; - private static final JavaFileGenerator JAVA_FILE_GENERATOR = GeneratedClassesTests::generateJavaFile; + private final GeneratedClasses generatedClasses = new GeneratedClasses( + new ClassNameGenerator(Object.class)); @Test void createWhenClassNameGeneratorIsNullThrowsException() { @@ -45,61 +52,118 @@ class GeneratedClassesTests { } @Test - void getOrGenerateWithClassTargetWhenJavaFileGeneratorIsNullThrowsException() { + void forFeatureComponentWhenTargetIsNullThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> this.generatedClasses.getOrGenerateClass(null, - TestTarget.class, "test")) - .withMessage("'javaFileGenerator' must not be null"); + .isThrownBy(() -> this.generatedClasses.forFeatureComponent("test", null)) + .withMessage("'component' must not be null"); } @Test - void getOrGenerateWithClassTargetWhenTargetIsNullThrowsException() { + void forFeatureComponentWhenFeatureNameIsEmptyThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> this.generatedClasses - .getOrGenerateClass(JAVA_FILE_GENERATOR, (Class) null, "test")) - .withMessage("'target' must not be null"); - } - - @Test - void getOrGenerateWithClassTargetWhenFeatureIsNullThrowsException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.generatedClasses - .getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, null)) + .isThrownBy(() -> this.generatedClasses.forFeatureComponent("", TestComponent.class)) .withMessage("'featureName' must not be empty"); } @Test - void getOrGenerateWhenNewReturnsGeneratedMethod() { + void forFeatureWhenFeatureNameIsEmptyThrowsException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.generatedClasses.forFeature("")) + .withMessage("'featureName' must not be empty"); + } + + @Test + void generateWhenTypeSpecCustomizerIsNullThrowsException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.generatedClasses + .forFeatureComponent("test", TestComponent.class).generate(null)) + .withMessage("'typeSpecCustomizer' must not be null"); + } + + @Test + void forFeatureUsesDefaultTarget() { + GeneratedClass generatedClass = this.generatedClasses + .forFeature("Test").generate(emptyTypeCustomizer); + assertThat(generatedClass.getName()).hasToString("java.lang.Object__Test"); + } + + @Test + void forFeatureComponentUsesComponent() { + GeneratedClass generatedClass = this.generatedClasses + .forFeatureComponent("Test", TestComponent.class).generate(emptyTypeCustomizer); + assertThat(generatedClass.getName().toString()).endsWith("TestComponent__Test"); + } + + @Test + void generateReturnsDifferentInstances() { + Consumer typeCustomizer = mockTypeCustomizer(); GeneratedClass generatedClass1 = this.generatedClasses - .getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, "one"); + .forFeatureComponent("one", TestComponent.class).generate(typeCustomizer); GeneratedClass generatedClass2 = this.generatedClasses - .getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, "two"); + .forFeatureComponent("one", TestComponent.class).generate(typeCustomizer); + assertThat(generatedClass1).isNotSameAs(generatedClass2); + assertThat(generatedClass1.getName().simpleName()).endsWith("__One"); + assertThat(generatedClass2.getName().simpleName()).endsWith("__One1"); + } + + @Test + void getOrGenerateWhenNewReturnsGeneratedMethod() { + Consumer typeCustomizer = mockTypeCustomizer(); + GeneratedClass generatedClass1 = this.generatedClasses + .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); + GeneratedClass generatedClass2 = this.generatedClasses + .forFeatureComponent("two", TestComponent.class).getOrGenerate("facet", typeCustomizer); assertThat(generatedClass1).isNotNull().isNotEqualTo(generatedClass2); assertThat(generatedClass2).isNotNull(); } @Test void getOrGenerateWhenRepeatReturnsSameGeneratedMethod() { - GeneratedClasses generated = this.generatedClasses; - GeneratedClass generatedClass1 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, - TestTarget.class, "one"); - GeneratedClass generatedClass2 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, - TestTarget.class, "one"); - GeneratedClass generatedClass3 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, - TestTarget.class, "one"); - GeneratedClass generatedClass4 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, - TestTarget.class, "two"); + Consumer typeCustomizer = mockTypeCustomizer(); + GeneratedClass generatedClass1 = this.generatedClasses + .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); + GeneratedClass generatedClass2 = this.generatedClasses + .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); + GeneratedClass generatedClass3 = this.generatedClasses + .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); assertThat(generatedClass1).isNotNull().isSameAs(generatedClass2) - .isSameAs(generatedClass3).isNotSameAs(generatedClass4); + .isSameAs(generatedClass3); + verifyNoInteractions(typeCustomizer); + generatedClass1.generateJavaFile(); + verify(typeCustomizer).accept(any()); } - static JavaFile generateJavaFile(ClassName className, - GeneratedMethods generatedMethods) { - TypeSpec typeSpec = TypeSpec.classBuilder(className).addJavadoc("Test").build(); - return JavaFile.builder(className.packageName(), typeSpec).build(); + @Test + @SuppressWarnings("unchecked") + void writeToInvokeTypeSpecCustomizer() throws IOException { + Consumer typeSpecCustomizer = mock(Consumer.class); + this.generatedClasses.forFeatureComponent("one", TestComponent.class) + .generate(typeSpecCustomizer); + verifyNoInteractions(typeSpecCustomizer); + InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); + this.generatedClasses.writeTo(generatedFiles); + verify(typeSpecCustomizer).accept(any()); + assertThat(generatedFiles.getGeneratedFiles(Kind.SOURCE)).hasSize(1); } - private static class TestTarget { + @Test + void withNameUpdatesNamingConventions() { + GeneratedClass generatedClass1 = this.generatedClasses + .forFeatureComponent("one", TestComponent.class).generate(emptyTypeCustomizer); + GeneratedClass generatedClass2 = this.generatedClasses.withName("Another") + .forFeatureComponent("one", TestComponent.class).generate(emptyTypeCustomizer); + assertThat(generatedClass1.getName().toString()).endsWith("TestComponent__One"); + assertThat(generatedClass2.getName().toString()).endsWith("TestComponent__AnotherOne"); + } + + + @SuppressWarnings("unchecked") + private Consumer mockTypeCustomizer() { + return mock(Consumer.class); + } + + + private static class TestComponent { } diff --git a/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestGenerationContext.java b/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestGenerationContext.java new file mode 100644 index 0000000000..ef50d4d4ca --- /dev/null +++ b/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestGenerationContext.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.core.testfixture.aot.generate; + +import org.springframework.aot.generate.ClassNameGenerator; +import org.springframework.aot.generate.DefaultGenerationContext; +import org.springframework.aot.generate.GeneratedFiles; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.InMemoryGeneratedFiles; + +/** + * Test {@link GenerationContext} implementation that uses + * {@link TestTarget} as the main target. + * + * @author Stephane Nicoll + */ +public class TestGenerationContext extends DefaultGenerationContext { + + public TestGenerationContext(GeneratedFiles generatedFiles) { + super(new ClassNameGenerator(TestTarget.class), generatedFiles); + } + + public TestGenerationContext() { + this(new InMemoryGeneratedFiles()); + } +} diff --git a/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestTarget.java b/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestTarget.java new file mode 100644 index 0000000000..d1b5568c28 --- /dev/null +++ b/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestTarget.java @@ -0,0 +1,25 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.core.testfixture.aot.generate; + +/** + * A target used by tests of code generation. + * + * @author Stephane Nicoll + */ +public class TestTarget { +} diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java index 68b4a178ed..3363a6ec01 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java @@ -30,6 +30,7 @@ import java.util.Map; import java.util.Properties; import java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import jakarta.persistence.EntityManager; import jakarta.persistence.EntityManagerFactory; @@ -39,11 +40,10 @@ import jakarta.persistence.PersistenceProperty; import jakarta.persistence.PersistenceUnit; import jakarta.persistence.SynchronizationType; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; -import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodGenerator; -import org.springframework.aot.generate.MethodNameGenerator; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.BeanUtils; @@ -70,11 +70,9 @@ import org.springframework.core.BridgeMethodResolver; import org.springframework.core.Ordered; import org.springframework.core.PriorityOrdered; import org.springframework.core.annotation.AnnotationUtils; -import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; -import org.springframework.javapoet.TypeSpec; +import org.springframework.javapoet.MethodSpec.Builder; import org.springframework.jndi.JndiLocatorDelegate; import org.springframework.jndi.JndiTemplate; import org.springframework.lang.Nullable; @@ -789,34 +787,27 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar @Override public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { - ClassName className = generationContext.getClassNameGenerator() - .generateClassName(this.target, "PersistenceInjection"); - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - classBuilder.addJavadoc("Persistence injection for {@link $T}.", this.target); - classBuilder.addModifiers(javax.lang.model.element.Modifier.PUBLIC); - GeneratedMethods methods = new GeneratedMethods( - new MethodNameGenerator(APPLY_METHOD)); - classBuilder.addMethod(generateMethod(generationContext.getRuntimeHints(), - className, methods)); - methods.doWithMethodSpecs(classBuilder::addMethod); - JavaFile javaFile = JavaFile - .builder(className.packageName(), classBuilder.build()).build(); - generationContext.getGeneratedFiles().addSourceFile(javaFile); + GeneratedClass generatedClass = generationContext.getGeneratedClasses() + .forFeatureComponent("PersistenceInjection", this.target).generate(type -> { + type.addJavadoc("Persistence injection for {@link $T}.", this.target); + type.addModifiers(javax.lang.model.element.Modifier.PUBLIC); + }); + generatedClass.getMethodGenerator().generateMethod(APPLY_METHOD) + .using(generateMethod(generationContext.getRuntimeHints(), generatedClass.getMethodGenerator())); beanRegistrationCode.addInstancePostProcessor( - MethodReference.ofStatic(className, APPLY_METHOD)); + MethodReference.ofStatic(generatedClass.getName(), APPLY_METHOD)); } - private MethodSpec generateMethod(RuntimeHints hints, ClassName className, - MethodGenerator methodGenerator) { - MethodSpec.Builder builder = MethodSpec.methodBuilder(APPLY_METHOD); - builder.addJavadoc("Apply the persistence injection."); - builder.addModifiers(javax.lang.model.element.Modifier.PUBLIC, - javax.lang.model.element.Modifier.STATIC); - builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); - builder.addParameter(this.target, INSTANCE_PARAMETER); - builder.returns(this.target); - builder.addCode(generateMethodCode(hints, methodGenerator)); - return builder.build(); + private Consumer generateMethod(RuntimeHints hints, MethodGenerator methodGenerator) { + return method -> { + method.addJavadoc("Apply the persistence injection."); + method.addModifiers(javax.lang.model.element.Modifier.PUBLIC, + javax.lang.model.element.Modifier.STATIC); + method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); + method.addParameter(this.target, INSTANCE_PARAMETER); + method.returns(this.target); + method.addCode(generateMethodCode(hints, methodGenerator)); + }; } private CodeBlock generateMethodCode(RuntimeHints hints, diff --git a/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java b/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java index f38e1da8eb..8430b95e09 100644 --- a/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java +++ b/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java @@ -43,6 +43,7 @@ import org.springframework.beans.factory.aot.BeanRegistrationCode; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -67,7 +68,7 @@ class PersistenceAnnotationBeanPostProcessorAotContributionTests { void setup() { this.beanFactory = new DefaultListableBeanFactory(); this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(generatedFiles); + this.generationContext = new TestGenerationContext(generatedFiles); } @Test @@ -183,6 +184,7 @@ class PersistenceAnnotationBeanPostProcessorAotContributionTests { .processAheadOfTime(registeredBean); BeanRegistrationCode beanRegistrationCode = mock(BeanRegistrationCode.class); contribution.applyTo(generationContext, beanRegistrationCode); + generationContext.writeGeneratedContent(); TestCompiler.forSystem().withFiles(generatedFiles) .compile(compiled -> result.accept(new Invoker(compiled), compiled)); }