diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java index 66a0394264..0f09a91f68 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -17,6 +17,7 @@ package org.springframework.beans.factory.aot; import java.util.List; +import java.util.function.BiConsumer; import javax.lang.model.element.Modifier; @@ -52,6 +53,10 @@ class BeanRegistrationsAotContribution private static final String BEAN_FACTORY_PARAMETER_NAME = "beanFactory"; + private static final int MAX_REGISTRATIONS_PER_FILE = 5000; + + private static final int MAX_REGISTRATIONS_PER_METHOD = 1000; + private static final ArgumentCodeGenerator argumentCodeGenerator = ArgumentCodeGenerator .of(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME); @@ -67,14 +72,10 @@ class BeanRegistrationsAotContribution public void applyTo(GenerationContext generationContext, BeanFactoryInitializationCode beanFactoryInitializationCode) { - GeneratedClass generatedClass = generationContext.getGeneratedClasses() - .addForFeature("BeanFactoryRegistrations", type -> { - type.addJavadoc("Register bean definitions for the bean factory."); - type.addModifiers(Modifier.PUBLIC); - }); + GeneratedClass generatedClass = createBeanFactoryRegistrationClass(generationContext); BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator(generatedClass); - GeneratedMethod generatedBeanDefinitionsMethod = new BeanDefinitionsRegistrationGenerator( - generationContext, codeGenerator, this.registrations).generateRegisterBeanDefinitionsMethod(); + GeneratedMethod generatedBeanDefinitionsMethod = generateBeanRegistrationCode(generationContext, + generatedClass, codeGenerator); beanFactoryInitializationCode.addInitializer(generatedBeanDefinitionsMethod.toMethodReference()); GeneratedMethod generatedAliasesMethod = codeGenerator.getMethods().add("registerAliases", this::generateRegisterAliasesMethod); @@ -82,6 +83,48 @@ class BeanRegistrationsAotContribution generateRegisterHints(generationContext.getRuntimeHints(), this.registrations); } + private GeneratedMethod generateBeanRegistrationCode(GenerationContext generationContext, GeneratedClass mainGeneratedClass, BeanRegistrationsCodeGenerator mainCodeGenerator) { + if (this.registrations.size() < MAX_REGISTRATIONS_PER_FILE) { + return generateBeanRegistrationClass(generationContext, mainCodeGenerator, 0, this.registrations.size()); + } + else { + return mainGeneratedClass.getMethods().add("registerBeanDefinitions", method -> { + method.addJavadoc("Register the bean definitions."); + method.addModifiers(Modifier.PUBLIC); + method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME); + CodeBlock.Builder body = CodeBlock.builder(); + Registration.doWithSlice(this.registrations, MAX_REGISTRATIONS_PER_FILE, (start, end) -> { + GeneratedClass sliceGeneratedClass = createBeanFactoryRegistrationClass(generationContext); + BeanRegistrationsCodeGenerator sliceCodeGenerator = new BeanRegistrationsCodeGenerator(sliceGeneratedClass); + GeneratedMethod generatedMethod = generateBeanRegistrationClass(generationContext, sliceCodeGenerator, start, end); + body.addStatement(generatedMethod.toMethodReference().toInvokeCodeBlock(argumentCodeGenerator)); + }); + method.addCode(body.build()); + }); + } + } + + private GeneratedMethod generateBeanRegistrationClass(GenerationContext generationContext, + BeanRegistrationsCodeGenerator codeGenerator, int start, int end) { + + return codeGenerator.getMethods().add("registerBeanDefinitions", method -> { + method.addJavadoc("Register the bean definitions."); + method.addModifiers(Modifier.PUBLIC); + method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME); + List sliceRegistrations = this.registrations.subList(start, end); + new BeanDefinitionsRegistrationGenerator( + generationContext, codeGenerator, sliceRegistrations, start).generateBeanRegistrationsCode(method); + }); + } + + private static GeneratedClass createBeanFactoryRegistrationClass(GenerationContext generationContext) { + return generationContext.getGeneratedClasses() + .addForFeature("BeanFactoryRegistrations", type -> { + type.addJavadoc("Register bean definitions for the bean factory."); + type.addModifiers(Modifier.PUBLIC); + }); + } + private void generateRegisterAliasesMethod(MethodSpec.Builder method) { method.addJavadoc("Register the aliases."); method.addModifiers(Modifier.PUBLIC); @@ -117,6 +160,28 @@ class BeanRegistrationsAotContribution return this.registeredBean.getBeanName(); } + /** + * Invoke an action for each slice of the given {@code registrations}. The + * {@code action} is invoked for each slice with the start and end index of the + * given list of registrations. Elements to process can be retrieved using + * {@link List#subList(int, int)}. + * @param registrations the registrations to process + * @param sliceSize the size of a slice + * @param action the action to invoke for each slice + */ + static void doWithSlice(List registrations, int sliceSize, + BiConsumer action) { + + int index = 0; + int end = 0; + while (end < registrations.size()) { + int start = index * sliceSize; + end = Math.min(start + sliceSize, registrations.size()); + action.accept(start, end); + index++; + } + } + } @@ -144,6 +209,10 @@ class BeanRegistrationsAotContribution } + /** + * Generate code for bean registrations. Limited to {@value #MAX_REGISTRATIONS_PER_METHOD} + * beans per method to avoid hitting a limit. + */ static final class BeanDefinitionsRegistrationGenerator { private final GenerationContext generationContext; @@ -152,44 +221,38 @@ class BeanRegistrationsAotContribution private final List registrations; + private final int globalStart; + BeanDefinitionsRegistrationGenerator(GenerationContext generationContext, - BeanRegistrationsCodeGenerator codeGenerator, List registrations) { + BeanRegistrationsCodeGenerator codeGenerator, List registrations, int globalStart) { this.generationContext = generationContext; this.codeGenerator = codeGenerator; this.registrations = registrations; + this.globalStart = globalStart; } - - GeneratedMethod generateRegisterBeanDefinitionsMethod() { - return this.codeGenerator.getMethods().add("registerBeanDefinitions", method -> { - method.addJavadoc("Register the bean definitions."); - method.addModifiers(Modifier.PUBLIC); - method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME); - if (this.registrations.size() <= 1000) { - generateRegisterBeanDefinitionMethods(method, this.registrations); - } - else { - Builder code = CodeBlock.builder(); - code.add("// Registration is sliced to avoid exceeding size limit\n"); - int index = 0; - int end = 0; - while (end < this.registrations.size()) { - int start = index * 1000; - end = Math.min(start + 1000, this.registrations.size()); - GeneratedMethod sliceMethod = generateSliceMethod(start, end); - code.addStatement(sliceMethod.toMethodReference().toInvokeCodeBlock( - argumentCodeGenerator, this.codeGenerator.getClassName())); - index++; - } - method.addCode(code.build()); - } - }); + void generateBeanRegistrationsCode(MethodSpec.Builder method) { + if (this.registrations.size() <= 1000) { + generateRegisterBeanDefinitionMethods(method, this.registrations); + } + else { + Builder code = CodeBlock.builder(); + code.add("// Registration is sliced to avoid exceeding size limit\n"); + Registration.doWithSlice(this.registrations, MAX_REGISTRATIONS_PER_METHOD, + (start, end) -> { + GeneratedMethod sliceMethod = generateSliceMethod(start, end); + code.addStatement(sliceMethod.toMethodReference().toInvokeCodeBlock( + argumentCodeGenerator, this.codeGenerator.getClassName())); + }); + method.addCode(code.build()); + } } private GeneratedMethod generateSliceMethod(int start, int end) { - String description = "Register the bean definitions from %s to %s.".formatted(start, end - 1); + String description = "Register the bean definitions from %s to %s." + .formatted(this.globalStart + start, this.globalStart + end - 1); List slice = this.registrations.subList(start, end); return this.codeGenerator.getMethods().add("registerBeanDefinitions", method -> { method.addJavadoc(description); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java index fb20052d61..2647f3d9c6 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Stream; import javax.lang.model.element.Modifier; @@ -54,6 +55,8 @@ import org.springframework.javapoet.ParameterizedTypeName; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.then; +import static org.mockito.Mockito.mock; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; /** @@ -230,6 +233,40 @@ class BeanRegistrationsAotContributionTests { }); } + @Test + void doWithSliceWithOnlyLessThanOneSlice() { + List registration = Stream.generate(() -> mock(Registration.class)).limit(10).toList(); + BiConsumer sliceAction = mockSliceAction(); + Registration.doWithSlice(registration, 20, sliceAction); + then(sliceAction).should().accept(0, 10); + then(sliceAction).shouldHaveNoMoreInteractions(); + } + + @Test + void doWithSliceWithOnlyOneExactSlice() { + List registration = Stream.generate(() -> mock(Registration.class)).limit(20).toList(); + BiConsumer sliceAction = mockSliceAction(); + Registration.doWithSlice(registration, 20, sliceAction); + then(sliceAction).should().accept(0, 20); + then(sliceAction).shouldHaveNoMoreInteractions(); + } + + @Test + void doWithSeveralSlices() { + List registration = Stream.generate(() -> mock(Registration.class)).limit(20).toList(); + BiConsumer sliceAction = mockSliceAction(); + Registration.doWithSlice(registration, 7, sliceAction); + then(sliceAction).should().accept(0, 7); + then(sliceAction).should().accept(7, 14); + then(sliceAction).should().accept(14, 20); + then(sliceAction).shouldHaveNoMoreInteractions(); + } + + @SuppressWarnings("unchecked") + BiConsumer mockSliceAction() { + return mock(BiConsumer.class); + } + @Test void applyToWithLargeBeanDefinitionsCreatesSlices() { BeanRegistrationsAotContribution contribution = createContribution(1001, i -> "testBean" + i); @@ -250,6 +287,38 @@ class BeanRegistrationsAotContributionTests { }); } + @Test + void applyToWithVeryLargeBeanDefinitionsCreatesSeparateSourceFiles() { + BeanRegistrationsAotContribution contribution = createContribution(10001, i -> "testBean" + i); + contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); + compile((consumer, compiled) -> { + assertThat(compiled.getSourceFile(".*BeanFactoryRegistrations1")) + .contains("Register the bean definitions from 0 to 999.", + "Register the bean definitions from 1000 to 1999.", + "Register the bean definitions from 2000 to 2999.", + "Register the bean definitions from 3000 to 3999.", + "Register the bean definitions from 4000 to 4999.", + "// Registration is sliced to avoid exceeding size limit"); + assertThat(compiled.getSourceFile(".*BeanFactoryRegistrations2")) + .contains("Register the bean definitions from 5000 to 5999.", + "Register the bean definitions from 6000 to 6999.", + "Register the bean definitions from 7000 to 7999.", + "Register the bean definitions from 8000 to 8999.", + "Register the bean definitions from 9000 to 9999.", + "// Registration is sliced to avoid exceeding size limit"); + assertThat(compiled.getSourceFile(".*BeanFactoryRegistrations3")) + .doesNotContain("// Registration is sliced to avoid exceeding size limit"); + DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); + consumer.accept(freshBeanFactory); + for (int i = 0; i < 10001; i++) { + String beanName = "testBean" + i; + assertThat(freshBeanFactory.containsBeanDefinition(beanName)).isTrue(); + assertThat(freshBeanFactory.getBean(beanName)).isInstanceOf(TestBean.class); + } + assertThat(freshBeanFactory.getBeansOfType(TestBean.class)).hasSize(10001); + }); + } + private BeanRegistrationsAotContribution createContribution(int size, Function beanNameFactory) { List registrations = new ArrayList<>(); for (int i = 0; i < size; i++) {