Avoid code too large with AOT processing

This commit adapts code generation to "slice" the registration of bean
definitions in separate bean methods rather than a unique method for
all of them.

If the bean factory has more than a thousand bean, a method is created
for each slice of 1000 bean definitions.

Closes gh-33126
This commit is contained in:
Stéphane Nicoll 2024-07-17 16:02:41 +02:00
parent 48dead4017
commit 30a64d6a0b
2 changed files with 145 additions and 29 deletions

View File

@ -33,6 +33,7 @@ import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.javapoet.MethodSpec;
/**
@ -51,6 +52,9 @@ class BeanRegistrationsAotContribution
private static final String BEAN_FACTORY_PARAMETER_NAME = "beanFactory";
private static final ArgumentCodeGenerator argumentCodeGenerator = ArgumentCodeGenerator
.of(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME);
private final List<Registration> registrations;
@ -69,8 +73,8 @@ class BeanRegistrationsAotContribution
type.addModifiers(Modifier.PUBLIC);
});
BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator(generatedClass);
GeneratedMethod generatedBeanDefinitionsMethod = codeGenerator.getMethods().add("registerBeanDefinitions", method ->
generateRegisterBeanDefinitionsMethod(method, generationContext, codeGenerator));
GeneratedMethod generatedBeanDefinitionsMethod = new BeanDefinitionsRegistrationGenerator(
generationContext, codeGenerator, this.registrations).generateRegisterBeanDefinitionsMethod();
beanFactoryInitializationCode.addInitializer(generatedBeanDefinitionsMethod.toMethodReference());
GeneratedMethod generatedAliasesMethod = codeGenerator.getMethods().add("registerAliases",
this::generateRegisterAliasesMethod);
@ -78,33 +82,6 @@ class BeanRegistrationsAotContribution
generateRegisterHints(generationContext.getRuntimeHints(), this.registrations);
}
private void generateRegisterBeanDefinitionsMethod(MethodSpec.Builder method,
GenerationContext generationContext, BeanRegistrationsCode beanRegistrationsCode) {
method.addJavadoc("Register the bean definitions.");
method.addModifiers(Modifier.PUBLIC);
method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME);
CodeBlock.Builder code = CodeBlock.builder();
this.registrations.forEach(registration -> {
try {
MethodReference beanDefinitionMethod = registration.methodGenerator
.generateBeanDefinitionMethod(generationContext, beanRegistrationsCode);
CodeBlock methodInvocation = beanDefinitionMethod.toInvokeCodeBlock(
ArgumentCodeGenerator.none(), beanRegistrationsCode.getClassName());
code.addStatement("$L.registerBeanDefinition($S, $L)",
BEAN_FACTORY_PARAMETER_NAME, registration.beanName(), methodInvocation);
}
catch (AotException ex) {
throw ex;
}
catch (Exception ex) {
throw new AotBeanProcessingException(registration.registeredBean,
"failed to generate code for bean definition", ex);
}
});
method.addCode(code.build());
}
private void generateRegisterAliasesMethod(MethodSpec.Builder method) {
method.addJavadoc("Register the aliases.");
method.addModifiers(Modifier.PUBLIC);
@ -167,4 +144,89 @@ class BeanRegistrationsAotContribution
}
static final class BeanDefinitionsRegistrationGenerator {
private final GenerationContext generationContext;
private final BeanRegistrationsCodeGenerator codeGenerator;
private final List<Registration> registrations;
BeanDefinitionsRegistrationGenerator(GenerationContext generationContext,
BeanRegistrationsCodeGenerator codeGenerator, List<Registration> registrations) {
this.generationContext = generationContext;
this.codeGenerator = codeGenerator;
this.registrations = registrations;
}
GeneratedMethod generateRegisterBeanDefinitionsMethod() {
return this.codeGenerator.getMethods().add("registerBeanDefinitions", method -> {
method.addJavadoc("Register the bean definitions.");
method.addModifiers(Modifier.PUBLIC);
method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME);
if (this.registrations.size() <= 1000) {
generateRegisterBeanDefinitionMethods(method, this.registrations);
}
else {
Builder code = CodeBlock.builder();
code.add("// Registration is sliced to avoid exceeding size limit\n");
int index = 0;
int end = 0;
while (end < this.registrations.size()) {
int start = index * 1000;
end = Math.min(start + 1000, this.registrations.size());
GeneratedMethod sliceMethod = generateSliceMethod(start, end);
code.addStatement(sliceMethod.toMethodReference().toInvokeCodeBlock(
argumentCodeGenerator, this.codeGenerator.getClassName()));
index++;
}
method.addCode(code.build());
}
});
}
private GeneratedMethod generateSliceMethod(int start, int end) {
String description = "Register the bean definitions from %s to %s.".formatted(start, end - 1);
List<Registration> slice = this.registrations.subList(start, end);
return this.codeGenerator.getMethods().add("registerBeanDefinitions", method -> {
method.addJavadoc(description);
method.addModifiers(Modifier.PRIVATE);
method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME);
generateRegisterBeanDefinitionMethods(method, slice);
});
}
private void generateRegisterBeanDefinitionMethods(MethodSpec.Builder method,
Iterable<Registration> registrations) {
CodeBlock.Builder code = CodeBlock.builder();
registrations.forEach(registration -> {
try {
CodeBlock methodInvocation = generateBeanRegistration(registration);
code.addStatement("$L.registerBeanDefinition($S, $L)",
BEAN_FACTORY_PARAMETER_NAME, registration.beanName(), methodInvocation);
}
catch (AotException ex) {
throw ex;
}
catch (Exception ex) {
throw new AotBeanProcessingException(registration.registeredBean,
"failed to generate code for bean definition", ex);
}
});
method.addCode(code.build());
}
private CodeBlock generateBeanRegistration(Registration registration) {
MethodReference beanDefinitionMethod = registration.methodGenerator
.generateBeanDefinitionMethod(this.generationContext, this.codeGenerator);
return beanDefinitionMethod.toInvokeCodeBlock(
ArgumentCodeGenerator.none(), this.codeGenerator.getClassName());
}
}
}

View File

@ -20,6 +20,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.lang.model.element.Modifier;
@ -210,6 +211,59 @@ class BeanRegistrationsAotContributionTests {
.havingCause().isInstanceOf(IllegalStateException.class).withMessage("Test exception");
}
@Test
void applyToWithLessThanAThousandBeanDefinitionsDoesNotCreateSlices() {
BeanRegistrationsAotContribution contribution = createContribution(999, i -> "testBean" + i);
contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode);
compile((consumer, compiled) -> {
assertThat(compiled.getSourceFile(".*BeanFactoryRegistrations"))
.doesNotContain("Register the bean definitions from 0 to 999.",
"// Registration is sliced to avoid exceeding size limit");
DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory();
consumer.accept(freshBeanFactory);
for (int i = 0; i < 999; i++) {
String beanName = "testBean" + i;
assertThat(freshBeanFactory.containsBeanDefinition(beanName)).isTrue();
assertThat(freshBeanFactory.getBean(beanName)).isInstanceOf(TestBean.class);
}
assertThat(freshBeanFactory.getBeansOfType(TestBean.class)).hasSize(999);
});
}
@Test
void applyToWithLargeBeanDefinitionsCreatesSlices() {
BeanRegistrationsAotContribution contribution = createContribution(1001, i -> "testBean" + i);
contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode);
compile((consumer, compiled) -> {
assertThat(compiled.getSourceFile(".*BeanFactoryRegistrations"))
.contains("Register the bean definitions from 0 to 999.",
"Register the bean definitions from 1000 to 1000.",
"// Registration is sliced to avoid exceeding size limit");
DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory();
consumer.accept(freshBeanFactory);
for (int i = 0; i < 1001; i++) {
String beanName = "testBean" + i;
assertThat(freshBeanFactory.containsBeanDefinition(beanName)).isTrue();
assertThat(freshBeanFactory.getBean(beanName)).isInstanceOf(TestBean.class);
}
assertThat(freshBeanFactory.getBeansOfType(TestBean.class)).hasSize(1001);
});
}
private BeanRegistrationsAotContribution createContribution(int size, Function<Integer, String> beanNameFactory) {
List<Registration> registrations = new ArrayList<>();
for (int i = 0; i < size; i++) {
String beanName = beanNameFactory.apply(i);
RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class);
this.beanFactory.registerBeanDefinition(beanName, beanDefinition);
RegisteredBean registeredBean = RegisteredBean.of(this.beanFactory, beanName);
BeanDefinitionMethodGenerator methodGenerator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null, List.of());
registrations.add(new Registration(registeredBean, methodGenerator, new String[0]));
}
return new BeanRegistrationsAotContribution(registrations);
}
private RegisteredBean registerBean(RootBeanDefinition rootBeanDefinition) {
String beanName = "testBean";
this.beanFactory.registerBeanDefinition(beanName, rootBeanDefinition);