Allow ApplicationContextAotGenerator to generated better class names

Update `ApplicationContextAotGenerator` so that it can generate class
names based on a `target` class and using the ID of the application
context. Prior to this commit, the generated class name was always
`__.BeanFactoryRegistrations`.

Closes gh-28565
This commit is contained in:
Phillip Webb 2022-06-03 16:43:13 -07:00
parent 243350054b
commit 4bd33cb6e0
9 changed files with 99 additions and 132 deletions

View File

@ -18,6 +18,7 @@ package org.springframework.beans.factory.aot;
import org.springframework.aot.generate.MethodGenerator;
import org.springframework.aot.generate.MethodReference;
import org.springframework.lang.Nullable;
/**
* Interface that can be used to configure the code that will be generated to
@ -34,6 +35,24 @@ public interface BeanFactoryInitializationCode {
*/
String BEAN_FACTORY_VARIABLE = "beanFactory";
/**
* Return the target class for this bean factory or {@code null} if there is
* no target.
* @return the target
*/
@Nullable
default Class<?> getTarget() {
return null;
}
/**
* Return the ID of the bean factory or and empty string if no ID is avaialble.
* @return the bean factory ID
*/
default String getId() {
return "";
}
/**
* Return a {@link MethodGenerator} that can be used to add more methods to
* the Initializing code.

View File

@ -61,8 +61,9 @@ class BeanRegistrationsAotContribution
public void applyTo(GenerationContext generationContext,
BeanFactoryInitializationCode beanFactoryInitializationCode) {
ClassName className = generationContext.getClassNameGenerator()
.generateClassName("BeanFactory", "Registrations");
ClassName className = generationContext.getClassNameGenerator().generateClassName(
beanFactoryInitializationCode.getTarget(),
beanFactoryInitializationCode.getId() + "BeanFactoryRegistrations");
BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator(
className);
GeneratedMethod registerMethod = codeGenerator.getMethodGenerator()

View File

@ -23,6 +23,7 @@ import org.springframework.context.ApplicationContext;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.JavaFile;
import org.springframework.lang.Nullable;
/**
* Process an {@link ApplicationContext} and its {@link BeanFactory} to generate
@ -48,10 +49,29 @@ public class ApplicationContextAotGenerator {
GenerationContext generationContext,
ClassName generatedInitializerClassName) {
generateApplicationContext(applicationContext, null, generationContext,
generatedInitializerClassName);
}
/**
* Refresh the specified {@link GenericApplicationContext} and generate the
* necessary code to restore the state of its {@link BeanFactory}, using the
* specified {@link GenerationContext}.
* @param applicationContext the application context to handle
* @param target the target class for the generated initializer
* @param generationContext the generation context to use
* @param generatedInitializerClassName the class name to use for the
* generated application context initializer
*/
public void generateApplicationContext(GenericApplicationContext applicationContext,
@Nullable Class<?> target, GenerationContext generationContext,
ClassName generatedInitializerClassName) {
applicationContext.refreshForAotProcessing();
DefaultListableBeanFactory beanFactory = applicationContext
.getDefaultListableBeanFactory();
ApplicationContextInitializationCodeGenerator codeGenerator = new ApplicationContextInitializationCodeGenerator();
ApplicationContextInitializationCodeGenerator codeGenerator = new ApplicationContextInitializationCodeGenerator(
target, applicationContext.getId());
new BeanFactoryInitializationAotContributions(beanFactory).applyTo(generationContext,
codeGenerator);
JavaFile javaFile = codeGenerator.generateJavaFile(generatedInitializerClassName);

View File

@ -35,6 +35,7 @@ import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.ParameterizedTypeName;
import org.springframework.javapoet.TypeSpec;
import org.springframework.util.StringUtils;
/**
* Internal code generator to create the application context initializer.
@ -48,11 +49,31 @@ class ApplicationContextInitializationCodeGenerator
private static final String APPLICATION_CONTEXT_VARIABLE = "applicationContext";
private final Class<?> target;
private final String id;
private final GeneratedMethods generatedMethods = new GeneratedMethods();
private final List<MethodReference> initializers = new ArrayList<>();
ApplicationContextInitializationCodeGenerator(Class<?> target, String id) {
this.target=target;
this.id = (!StringUtils.hasText(id)) ? "" : id;
}
@Override
public Class<?> getTarget() {
return this.target;
}
@Override
public String getId() {
return this.id;
}
@Override
public MethodGenerator getMethodGenerator() {
return this.generatedMethods;

View File

@ -43,18 +43,6 @@ public interface ClassGenerator {
GeneratedClass getOrGenerateClass(JavaFileGenerator javaFileGenerator,
Class<?> target, String featureName);
/**
* Get or generate a new {@link GeneratedClass} for a given java file
* generator, target and feature name.
* @param javaFileGenerator the java file generator
* @param target the target of the newly generated class
* @param featureName the name of the feature that the generated class
* supports
* @return a {@link GeneratedClass} instance
*/
GeneratedClass getOrGenerateClass(JavaFileGenerator javaFileGenerator, String target,
String featureName);
/**
* Strategy used to generate the java file for the generated class.

View File

@ -21,6 +21,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.springframework.javapoet.ClassName;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
@ -38,8 +39,9 @@ public final class ClassNameGenerator {
private static final String SEPARATOR = "__";
private static final String AOT_PACKAGE = "__";
private static final String AOT_PACKAGE = "__.";
private static final String AOT_FEATURE = "Aot";
private final Map<String, AtomicInteger> sequenceGenerator = new ConcurrentHashMap<>();
@ -47,53 +49,37 @@ public final class ClassNameGenerator {
/**
* Generate a new class name for the given {@code target} /
* {@code featureName} combination.
* @param target the target of the newly generated class
* @param target the target of the newly generated class or {@code null} if
* there is not target.
* @param featureName the name of the feature that the generated class
* supports
* @return a unique generated class name
*/
public ClassName generateClassName(Class<?> target, String featureName) {
Assert.notNull(target, "'target' must not be null");
String rootName = target.getName().replace("$", "_");
return generateSequencedClassName(rootName, featureName);
}
/**
* Generate a new class name for the given {@code name} /
* {@code featureName} combination.
* @param target the target of the newly generated class. When possible,
* this should be a class name
* @param featureName the name of the feature that the generated class
* supports
* @return a unique generated class name
*/
public ClassName generateClassName(String target, String featureName) {
Assert.hasLength(target, "'target' must not be empty");
target = clean(target);
String rootName = AOT_PACKAGE + "." + ((!target.isEmpty()) ? target : "Aot");
return generateSequencedClassName(rootName, featureName);
public ClassName generateClassName(@Nullable Class<?> target, String featureName) {
Assert.hasLength(featureName, "'featureName' must not be empty");
featureName = clean(featureName);
if(target != null) {
return generateSequencedClassName(target.getName().replace("$", "_") + SEPARATOR + StringUtils.capitalize(featureName));
}
return generateSequencedClassName(AOT_PACKAGE+ featureName);
}
private String clean(String name) {
StringBuilder rootName = new StringBuilder();
StringBuilder clean = new StringBuilder();
boolean lastNotLetter = true;
for (char ch : name.toCharArray()) {
if (!Character.isLetter(ch)) {
lastNotLetter = true;
continue;
}
rootName.append(lastNotLetter ? Character.toUpperCase(ch) : ch);
clean.append(lastNotLetter ? Character.toUpperCase(ch) : ch);
lastNotLetter = false;
}
return rootName.toString();
return (!clean.isEmpty()) ? clean.toString() : AOT_FEATURE;
}
private ClassName generateSequencedClassName(String rootName, String featureName) {
Assert.hasLength(featureName, "'featureName' must not be empty");
Assert.isTrue(featureName.chars().allMatch(Character::isLetter),
"'featureName' must contain only letters");
String name = addSequence(
rootName + SEPARATOR + StringUtils.capitalize(featureName));
private ClassName generateSequencedClassName(String name) {
name = addSequence(name);
return ClassName.get(ClassUtils.getPackageName(name),
ClassUtils.getShortName(name));
}

View File

@ -58,19 +58,6 @@ public class GeneratedClasses implements ClassGenerator {
this.classNameGenerator.generateClassName(target, featureName)));
}
@Override
public GeneratedClass getOrGenerateClass(JavaFileGenerator javaFileGenerator,
String target, String featureName) {
Assert.notNull(javaFileGenerator, "'javaFileGenerator' must not be null");
Assert.hasLength(target, "'target' must not be empty");
Assert.hasLength(featureName, "'featureName' must not be empty");
Owner owner = new Owner(javaFileGenerator, target, featureName);
return this.classes.computeIfAbsent(owner,
key -> new GeneratedClass(javaFileGenerator,
this.classNameGenerator.generateClassName(target, featureName)));
}
/**
* Write generated Spring {@code .factories} files to the given
* {@link GeneratedFiles} instance.

View File

@ -35,18 +35,9 @@ class ClassNameGeneratorTests {
private final ClassNameGenerator generator = new ClassNameGenerator();
@Test
void generateClassNameWhenTargetClassIsNullThrowsException() {
assertThatIllegalArgumentException()
.isThrownBy(
() -> this.generator.generateClassName((Class<?>) null, "Test"))
.withMessage("'target' must not be null");
}
@Test
void generateClassNameWhenTargetStringIsEmptyThrowsException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.generator.generateClassName("", "Test"))
.withMessage("'target' must not be empty");
void generateClassNameWhenTargetClassIsNullUsesAotPackage() {
ClassName generated = this.generator.generateClassName((Class<?>) null, "test");
assertThat(generated).hasToString("__.Test");
}
@Test
@ -58,17 +49,12 @@ class ClassNameGeneratorTests {
@Test
void generatedClassNameWhenFeatureIsNotAllLettersThrowsException() {
String expectedMessage = "'featureName' must contain only letters";
assertThatIllegalArgumentException().isThrownBy(
() -> this.generator.generateClassName(InputStream.class, "noway!"))
.withMessage(expectedMessage);
assertThatIllegalArgumentException().isThrownBy(
() -> this.generator.generateClassName(InputStream.class, "1WontWork"))
.withMessage(expectedMessage);
assertThatIllegalArgumentException()
.isThrownBy(
() -> this.generator.generateClassName(InputStream.class, "N0pe"))
.withMessage(expectedMessage);
assertThat(this.generator.generateClassName(InputStream.class, "name!"))
.hasToString("java.io.InputStream__Name");
assertThat(this.generator.generateClassName(InputStream.class, "1NameHere"))
.hasToString("java.io.InputStream__NameHere");
assertThat(this.generator.generateClassName(InputStream.class, "Y0pe"))
.hasToString("java.io.InputStream__YPe");
}
@Test
@ -80,38 +66,21 @@ class ClassNameGeneratorTests {
@Test
void generateClassNameWithClassWhenInnerClassGeneratesName() {
ClassName generated = this.generator.generateClassName(TestBean.class,
"EventListener");
assertThat(generated).hasToString(
"org.springframework.aot.generate.ClassNameGeneratorTests_TestBean__EventListener");
ClassName generated = this.generator.generateClassName(TestBean.class, "EventListener");
assertThat(generated)
.hasToString("org.springframework.aot.generate.ClassNameGeneratorTests_TestBean__EventListener");
}
@Test
void generateClassWithClassWhenMultipleCallsGeneratesSequencedName() {
ClassName generated1 = this.generator.generateClassName(InputStream.class,
"bytes");
ClassName generated2 = this.generator.generateClassName(InputStream.class,
"bytes");
ClassName generated3 = this.generator.generateClassName(InputStream.class,
"bytes");
ClassName generated1 = this.generator.generateClassName(InputStream.class, "bytes");
ClassName generated2 = this.generator.generateClassName(InputStream.class, "bytes");
ClassName generated3 = this.generator.generateClassName(InputStream.class, "bytes");
assertThat(generated1).hasToString("java.io.InputStream__Bytes");
assertThat(generated2).hasToString("java.io.InputStream__Bytes1");
assertThat(generated3).hasToString("java.io.InputStream__Bytes2");
}
@Test
void generateClassNameWithStringGeneratesNameUsingOnlyLetters() {
ClassName generated = this.generator.generateClassName("my-bean--factoryStuff",
"beans");
assertThat(generated).hasToString("__.MyBeanFactoryStuff__Beans");
}
@Test
void generateClassNameWithStringWhenNoLettersGeneratesAotName() {
ClassName generated = this.generator.generateClassName("1234!@#", "beans");
assertThat(generated).hasToString("__.Aot__Beans");
}
static class TestBean {
}

View File

@ -68,36 +68,12 @@ class GeneratedClassesTests {
.withMessage("'featureName' must not be empty");
}
@Test
void getOrGenerateWithStringTargetWhenJavaFileGeneratorIsNullThrowsException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.generatedClasses.getOrGenerateClass(null,
TestTarget.class.getName(), "test"))
.withMessage("'javaFileGenerator' must not be null");
}
@Test
void getOrGenerateWithStringTargetWhenTargetIsNullThrowsException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.generatedClasses
.getOrGenerateClass(JAVA_FILE_GENERATOR, (String) null, "test"))
.withMessage("'target' must not be empty");
}
@Test
void getOrGenerateWithStringTargetWhenFeatureIsNullThrowsException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.generatedClasses.getOrGenerateClass(
JAVA_FILE_GENERATOR, TestTarget.class.getName(), null))
.withMessage("'featureName' must not be empty");
}
@Test
void getOrGenerateWhenNewReturnsGeneratedMethod() {
GeneratedClass generatedClass1 = this.generatedClasses
.getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, "one");
GeneratedClass generatedClass2 = this.generatedClasses.getOrGenerateClass(
JAVA_FILE_GENERATOR, TestTarget.class.getName(), "two");
GeneratedClass generatedClass2 = this.generatedClasses
.getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, "two");
assertThat(generatedClass1).isNotNull().isNotEqualTo(generatedClass2);
assertThat(generatedClass2).isNotNull();
}
@ -110,7 +86,7 @@ class GeneratedClassesTests {
GeneratedClass generatedClass2 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR,
TestTarget.class, "one");
GeneratedClass generatedClass3 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR,
TestTarget.class.getName(), "one");
TestTarget.class, "one");
GeneratedClass generatedClass4 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR,
TestTarget.class, "two");
assertThat(generatedClass1).isNotNull().isSameAs(generatedClass2)