From e88b70e6ad09ab7565bedded8d22eddc026f147a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Mon, 30 Jun 2025 12:38:42 +0200 Subject: [PATCH] Slice bean registrations in separate file if necessary The compiler has a constants pool limit of 65536 entries per source file which can be hit with a very large amount of beans to register in the bean factory. This commit makes sure to create separate source files if the number of beans to register is very large. The main generated source file delegate to those. Closes gh-35044 --- .../aot/BeanRegistrationsAotContribution.java | 131 +++++++++++++----- ...BeanRegistrationsAotContributionTests.java | 69 +++++++++ 2 files changed, 166 insertions(+), 34 deletions(-) diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java index 66a0394264..0f09a91f68 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -17,6 +17,7 @@ package org.springframework.beans.factory.aot; import java.util.List; +import java.util.function.BiConsumer; import javax.lang.model.element.Modifier; @@ -52,6 +53,10 @@ class BeanRegistrationsAotContribution private static final String BEAN_FACTORY_PARAMETER_NAME = "beanFactory"; + private static final int MAX_REGISTRATIONS_PER_FILE = 5000; + + private static final int MAX_REGISTRATIONS_PER_METHOD = 1000; + private static final ArgumentCodeGenerator argumentCodeGenerator = ArgumentCodeGenerator .of(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME); @@ -67,14 +72,10 @@ class BeanRegistrationsAotContribution public void applyTo(GenerationContext generationContext, BeanFactoryInitializationCode beanFactoryInitializationCode) { - GeneratedClass generatedClass = generationContext.getGeneratedClasses() - .addForFeature("BeanFactoryRegistrations", type -> { - type.addJavadoc("Register bean definitions for the bean factory."); - type.addModifiers(Modifier.PUBLIC); - }); + GeneratedClass generatedClass = createBeanFactoryRegistrationClass(generationContext); BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator(generatedClass); - GeneratedMethod generatedBeanDefinitionsMethod = new BeanDefinitionsRegistrationGenerator( - generationContext, codeGenerator, this.registrations).generateRegisterBeanDefinitionsMethod(); + GeneratedMethod generatedBeanDefinitionsMethod = generateBeanRegistrationCode(generationContext, + generatedClass, codeGenerator); beanFactoryInitializationCode.addInitializer(generatedBeanDefinitionsMethod.toMethodReference()); GeneratedMethod generatedAliasesMethod = codeGenerator.getMethods().add("registerAliases", this::generateRegisterAliasesMethod); @@ -82,6 +83,48 @@ class BeanRegistrationsAotContribution generateRegisterHints(generationContext.getRuntimeHints(), this.registrations); } + private GeneratedMethod generateBeanRegistrationCode(GenerationContext generationContext, GeneratedClass mainGeneratedClass, BeanRegistrationsCodeGenerator mainCodeGenerator) { + if (this.registrations.size() < MAX_REGISTRATIONS_PER_FILE) { + return generateBeanRegistrationClass(generationContext, mainCodeGenerator, 0, this.registrations.size()); + } + else { + return mainGeneratedClass.getMethods().add("registerBeanDefinitions", method -> { + method.addJavadoc("Register the bean definitions."); + method.addModifiers(Modifier.PUBLIC); + method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME); + CodeBlock.Builder body = CodeBlock.builder(); + Registration.doWithSlice(this.registrations, MAX_REGISTRATIONS_PER_FILE, (start, end) -> { + GeneratedClass sliceGeneratedClass = createBeanFactoryRegistrationClass(generationContext); + BeanRegistrationsCodeGenerator sliceCodeGenerator = new BeanRegistrationsCodeGenerator(sliceGeneratedClass); + GeneratedMethod generatedMethod = generateBeanRegistrationClass(generationContext, sliceCodeGenerator, start, end); + body.addStatement(generatedMethod.toMethodReference().toInvokeCodeBlock(argumentCodeGenerator)); + }); + method.addCode(body.build()); + }); + } + } + + private GeneratedMethod generateBeanRegistrationClass(GenerationContext generationContext, + BeanRegistrationsCodeGenerator codeGenerator, int start, int end) { + + return codeGenerator.getMethods().add("registerBeanDefinitions", method -> { + method.addJavadoc("Register the bean definitions."); + method.addModifiers(Modifier.PUBLIC); + method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME); + List sliceRegistrations = this.registrations.subList(start, end); + new BeanDefinitionsRegistrationGenerator( + generationContext, codeGenerator, sliceRegistrations, start).generateBeanRegistrationsCode(method); + }); + } + + private static GeneratedClass createBeanFactoryRegistrationClass(GenerationContext generationContext) { + return generationContext.getGeneratedClasses() + .addForFeature("BeanFactoryRegistrations", type -> { + type.addJavadoc("Register bean definitions for the bean factory."); + type.addModifiers(Modifier.PUBLIC); + }); + } + private void generateRegisterAliasesMethod(MethodSpec.Builder method) { method.addJavadoc("Register the aliases."); method.addModifiers(Modifier.PUBLIC); @@ -117,6 +160,28 @@ class BeanRegistrationsAotContribution return this.registeredBean.getBeanName(); } + /** + * Invoke an action for each slice of the given {@code registrations}. The + * {@code action} is invoked for each slice with the start and end index of the + * given list of registrations. Elements to process can be retrieved using + * {@link List#subList(int, int)}. + * @param registrations the registrations to process + * @param sliceSize the size of a slice + * @param action the action to invoke for each slice + */ + static void doWithSlice(List registrations, int sliceSize, + BiConsumer action) { + + int index = 0; + int end = 0; + while (end < registrations.size()) { + int start = index * sliceSize; + end = Math.min(start + sliceSize, registrations.size()); + action.accept(start, end); + index++; + } + } + } @@ -144,6 +209,10 @@ class BeanRegistrationsAotContribution } + /** + * Generate code for bean registrations. Limited to {@value #MAX_REGISTRATIONS_PER_METHOD} + * beans per method to avoid hitting a limit. + */ static final class BeanDefinitionsRegistrationGenerator { private final GenerationContext generationContext; @@ -152,44 +221,38 @@ class BeanRegistrationsAotContribution private final List registrations; + private final int globalStart; + BeanDefinitionsRegistrationGenerator(GenerationContext generationContext, - BeanRegistrationsCodeGenerator codeGenerator, List registrations) { + BeanRegistrationsCodeGenerator codeGenerator, List registrations, int globalStart) { this.generationContext = generationContext; this.codeGenerator = codeGenerator; this.registrations = registrations; + this.globalStart = globalStart; } - - GeneratedMethod generateRegisterBeanDefinitionsMethod() { - return this.codeGenerator.getMethods().add("registerBeanDefinitions", method -> { - method.addJavadoc("Register the bean definitions."); - method.addModifiers(Modifier.PUBLIC); - method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME); - if (this.registrations.size() <= 1000) { - generateRegisterBeanDefinitionMethods(method, this.registrations); - } - else { - Builder code = CodeBlock.builder(); - code.add("// Registration is sliced to avoid exceeding size limit\n"); - int index = 0; - int end = 0; - while (end < this.registrations.size()) { - int start = index * 1000; - end = Math.min(start + 1000, this.registrations.size()); - GeneratedMethod sliceMethod = generateSliceMethod(start, end); - code.addStatement(sliceMethod.toMethodReference().toInvokeCodeBlock( - argumentCodeGenerator, this.codeGenerator.getClassName())); - index++; - } - method.addCode(code.build()); - } - }); + void generateBeanRegistrationsCode(MethodSpec.Builder method) { + if (this.registrations.size() <= 1000) { + generateRegisterBeanDefinitionMethods(method, this.registrations); + } + else { + Builder code = CodeBlock.builder(); + code.add("// Registration is sliced to avoid exceeding size limit\n"); + Registration.doWithSlice(this.registrations, MAX_REGISTRATIONS_PER_METHOD, + (start, end) -> { + GeneratedMethod sliceMethod = generateSliceMethod(start, end); + code.addStatement(sliceMethod.toMethodReference().toInvokeCodeBlock( + argumentCodeGenerator, this.codeGenerator.getClassName())); + }); + method.addCode(code.build()); + } } private GeneratedMethod generateSliceMethod(int start, int end) { - String description = "Register the bean definitions from %s to %s.".formatted(start, end - 1); + String description = "Register the bean definitions from %s to %s." + .formatted(this.globalStart + start, this.globalStart + end - 1); List slice = this.registrations.subList(start, end); return this.codeGenerator.getMethods().add("registerBeanDefinitions", method -> { method.addJavadoc(description); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java index fb20052d61..2647f3d9c6 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Stream; import javax.lang.model.element.Modifier; @@ -54,6 +55,8 @@ import org.springframework.javapoet.ParameterizedTypeName; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.then; +import static org.mockito.Mockito.mock; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; /** @@ -230,6 +233,40 @@ class BeanRegistrationsAotContributionTests { }); } + @Test + void doWithSliceWithOnlyLessThanOneSlice() { + List registration = Stream.generate(() -> mock(Registration.class)).limit(10).toList(); + BiConsumer sliceAction = mockSliceAction(); + Registration.doWithSlice(registration, 20, sliceAction); + then(sliceAction).should().accept(0, 10); + then(sliceAction).shouldHaveNoMoreInteractions(); + } + + @Test + void doWithSliceWithOnlyOneExactSlice() { + List registration = Stream.generate(() -> mock(Registration.class)).limit(20).toList(); + BiConsumer sliceAction = mockSliceAction(); + Registration.doWithSlice(registration, 20, sliceAction); + then(sliceAction).should().accept(0, 20); + then(sliceAction).shouldHaveNoMoreInteractions(); + } + + @Test + void doWithSeveralSlices() { + List registration = Stream.generate(() -> mock(Registration.class)).limit(20).toList(); + BiConsumer sliceAction = mockSliceAction(); + Registration.doWithSlice(registration, 7, sliceAction); + then(sliceAction).should().accept(0, 7); + then(sliceAction).should().accept(7, 14); + then(sliceAction).should().accept(14, 20); + then(sliceAction).shouldHaveNoMoreInteractions(); + } + + @SuppressWarnings("unchecked") + BiConsumer mockSliceAction() { + return mock(BiConsumer.class); + } + @Test void applyToWithLargeBeanDefinitionsCreatesSlices() { BeanRegistrationsAotContribution contribution = createContribution(1001, i -> "testBean" + i); @@ -250,6 +287,38 @@ class BeanRegistrationsAotContributionTests { }); } + @Test + void applyToWithVeryLargeBeanDefinitionsCreatesSeparateSourceFiles() { + BeanRegistrationsAotContribution contribution = createContribution(10001, i -> "testBean" + i); + contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); + compile((consumer, compiled) -> { + assertThat(compiled.getSourceFile(".*BeanFactoryRegistrations1")) + .contains("Register the bean definitions from 0 to 999.", + "Register the bean definitions from 1000 to 1999.", + "Register the bean definitions from 2000 to 2999.", + "Register the bean definitions from 3000 to 3999.", + "Register the bean definitions from 4000 to 4999.", + "// Registration is sliced to avoid exceeding size limit"); + assertThat(compiled.getSourceFile(".*BeanFactoryRegistrations2")) + .contains("Register the bean definitions from 5000 to 5999.", + "Register the bean definitions from 6000 to 6999.", + "Register the bean definitions from 7000 to 7999.", + "Register the bean definitions from 8000 to 8999.", + "Register the bean definitions from 9000 to 9999.", + "// Registration is sliced to avoid exceeding size limit"); + assertThat(compiled.getSourceFile(".*BeanFactoryRegistrations3")) + .doesNotContain("// Registration is sliced to avoid exceeding size limit"); + DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); + consumer.accept(freshBeanFactory); + for (int i = 0; i < 10001; i++) { + String beanName = "testBean" + i; + assertThat(freshBeanFactory.containsBeanDefinition(beanName)).isTrue(); + assertThat(freshBeanFactory.getBean(beanName)).isInstanceOf(TestBean.class); + } + assertThat(freshBeanFactory.getBeansOfType(TestBean.class)).hasSize(10001); + }); + } + private BeanRegistrationsAotContribution createContribution(int size, Function beanNameFactory) { List registrations = new ArrayList<>(); for (int i = 0; i < size; i++) {