diff --git a/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java b/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java index a86805dd3e7..77873b7b36d 100644 --- a/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java +++ b/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java @@ -152,30 +152,30 @@ class ScopedProxyBeanRegistrationAotProcessor Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut) { - GeneratedMethod method = beanRegistrationCode.getMethodGenerator() - .generateMethod("get", "scopedProxyInstance").using(builder -> { + GeneratedMethod generatedMethod = beanRegistrationCode.getMethods() + .add("getScopedProxyInstance", method -> { Class beanClass = this.targetBeanDefinition.getResolvableType() .toClass(); - builder.addJavadoc( + method.addJavadoc( "Create the scoped proxy bean instance for '$L'.", this.registeredBean.getBeanName()); - builder.addModifiers(Modifier.PRIVATE, Modifier.STATIC); - builder.returns(beanClass); - builder.addParameter(RegisteredBean.class, + method.addModifiers(Modifier.PRIVATE, Modifier.STATIC); + method.returns(beanClass); + method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); - builder.addStatement("$T factory = new $T()", + method.addStatement("$T factory = new $T()", ScopedProxyFactoryBean.class, ScopedProxyFactoryBean.class); - builder.addStatement("factory.setTargetBeanName($S)", + method.addStatement("factory.setTargetBeanName($S)", this.targetBeanName); - builder.addStatement( + method.addStatement( "factory.setBeanFactory($L.getBeanFactory())", REGISTERED_BEAN_PARAMETER_NAME); - builder.addStatement("return ($T) factory.getObject()", + method.addStatement("return ($T) factory.getObject()", beanClass); }); return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, - beanRegistrationCode.getClassName(), method.getName()); + beanRegistrationCode.getClassName(), generatedMethod.getName()); } } diff --git a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java index d9897fa5881..81ea4fa4f14 100644 --- a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java +++ b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java @@ -16,11 +16,12 @@ package org.springframework.aop.scope; -import java.lang.reflect.Method; import java.util.Properties; import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import javax.lang.model.element.Modifier; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.aop.framework.AopInfrastructureBean; @@ -43,7 +44,9 @@ import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryIn import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolder; import org.springframework.core.ResolvableType; import org.springframework.core.testfixture.aot.generate.TestGenerationContext; -import org.springframework.util.ReflectionUtils; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.ParameterizedTypeName; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -57,25 +60,26 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; */ class ScopedProxyBeanRegistrationAotProcessorTests { - private DefaultListableBeanFactory beanFactory; + private final DefaultListableBeanFactory beanFactory; - private TestBeanRegistrationsAotProcessor processor; + private final TestBeanRegistrationsAotProcessor processor; - private InMemoryGeneratedFiles generatedFiles; + private final InMemoryGeneratedFiles generatedFiles; - private DefaultGenerationContext generationContext; + private final DefaultGenerationContext generationContext; - private MockBeanFactoryInitializationCode beanFactoryInitializationCode; + private final MockBeanFactoryInitializationCode beanFactoryInitializationCode; - @BeforeEach - void setup() { + + ScopedProxyBeanRegistrationAotProcessorTests() { this.beanFactory = new DefaultListableBeanFactory(); this.processor = new TestBeanRegistrationsAotProcessor(); this.generatedFiles = new InMemoryGeneratedFiles(); this.generationContext = new TestGenerationContext(this.generatedFiles); - this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); + this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(this.generationContext); } + @Test void scopedProxyBeanRegistrationAotProcessorIsRegistered() { assertThat(new AotFactoriesLoader(this.beanFactory).load(BeanRegistrationAotProcessor.class)) @@ -87,7 +91,7 @@ class ScopedProxyBeanRegistrationAotProcessorTests { BeanDefinition beanDefinition = BeanDefinitionBuilder .rootBeanDefinition(PropertiesFactoryBean.class).getBeanDefinition(); this.beanFactory.registerBeanDefinition("test", beanDefinition); - testCompile((freshBeanFactory, compiled) -> { + compile((freshBeanFactory, compiled) -> { Object bean = freshBeanFactory.getBean("test"); assertThat(bean).isInstanceOf(Properties.class); }); @@ -98,10 +102,9 @@ class ScopedProxyBeanRegistrationAotProcessorTests { BeanDefinition beanDefinition = BeanDefinitionBuilder .rootBeanDefinition(ScopedProxyFactoryBean.class).getBeanDefinition(); this.beanFactory.registerBeanDefinition("test", beanDefinition); - testCompile((freshBeanFactory, - compiled) -> assertThatExceptionOfType(BeanCreationException.class) - .isThrownBy(() -> freshBeanFactory.getBean("test")) - .withMessageContaining("'targetBeanName' is required")); + compile((freshBeanFactory, compiled) -> + assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> + freshBeanFactory.getBean("test")).withMessageContaining("'targetBeanName' is required")); } @Test @@ -111,10 +114,9 @@ class ScopedProxyBeanRegistrationAotProcessorTests { .addPropertyValue("targetBeanName", "testDoesNotExist") .getBeanDefinition(); this.beanFactory.registerBeanDefinition("test", beanDefinition); - testCompile((freshBeanFactory, - compiled) -> assertThatExceptionOfType(BeanCreationException.class) - .isThrownBy(() -> freshBeanFactory.getBean("test")) - .withMessageContaining("No bean named 'testDoesNotExist'")); + compile((freshBeanFactory, compiled) -> + assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> + freshBeanFactory.getBean("test")).withMessageContaining("No bean named 'testDoesNotExist'")); } @Test @@ -128,29 +130,32 @@ class ScopedProxyBeanRegistrationAotProcessorTests { .rootBeanDefinition(ScopedProxyFactoryBean.class) .addPropertyValue("targetBeanName", "numberHolder").getBeanDefinition(); this.beanFactory.registerBeanDefinition("test", scopedBean); - testCompile((freshBeanFactory, compiled) -> { + compile((freshBeanFactory, compiled) -> { Object bean = freshBeanFactory.getBean("test"); - assertThat(bean).isNotNull().isInstanceOf(NumberHolder.class) - .isInstanceOf(AopInfrastructureBean.class); + assertThat(bean).isNotNull().isInstanceOf(NumberHolder.class).isInstanceOf(AopInfrastructureBean.class); }); } - private void testCompile(BiConsumer result) { - BeanFactoryInitializationAotContribution contribution = this.processor - .processAheadOfTime(this.beanFactory); + @SuppressWarnings("unchecked") + private void compile(BiConsumer result) { + BeanFactoryInitializationAotContribution contribution = this.processor.processAheadOfTime(this.beanFactory); assertThat(contribution).isNotNull(); contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); + MethodReference methodReference = this.beanFactoryInitializationCode + .getInitializers().get(0); + this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + type.addModifiers(Modifier.PUBLIC); + type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); + type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) + .addParameter(DefaultListableBeanFactory.class, "beanFactory") + .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .build()); + }); this.generationContext.writeGeneratedContent(); TestCompiler.forSystem().withFiles(this.generatedFiles).compile(compiled -> { - MethodReference reference = this.beanFactoryInitializationCode - .getInitializers().get(0); - Object instance = compiled.getInstance(Object.class, - reference.getDeclaringClass().toString()); - Method method = ReflectionUtils.findMethod(instance.getClass(), - reference.getMethodName(), DefaultListableBeanFactory.class); DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); freshBeanFactory.setBeanClassLoader(compiled.getClassLoader()); - ReflectionUtils.invokeMethod(method, instance, freshBeanFactory); + compiled.getInstance(Consumer.class).accept(freshBeanFactory); result.accept(freshBeanFactory, compiled); }); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java index c75d94aa326..3a0ce718989 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java @@ -42,6 +42,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GeneratedClass; +import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.ExecutableHint; @@ -81,7 +82,6 @@ import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.core.annotation.MergedAnnotations; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.MethodSpec; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -881,8 +881,6 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA */ private static class AotContribution implements BeanRegistrationAotContribution { - private static final String APPLY_METHOD = "apply"; - private static final String REGISTERED_BEAN_PARAMETER = "registeredBean"; private static final String INSTANCE_PARAMETER = "instance"; @@ -909,27 +907,21 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { GeneratedClass generatedClass = generationContext.getGeneratedClasses() - .forFeatureComponent("Autowiring", this.target) - .generate(type -> { + .addForFeatureComponent("Autowiring", this.target, type -> { type.addJavadoc("Autowiring for {@link $T}.", this.target); type.addModifiers(javax.lang.model.element.Modifier.PUBLIC); }); - generatedClass.getMethodGenerator().generateMethod(APPLY_METHOD) - .using(generateMethod(generationContext.getRuntimeHints())); - beanRegistrationCode.addInstancePostProcessor( - MethodReference.ofStatic(generatedClass.getName(), APPLY_METHOD)); - } - - private Consumer generateMethod(RuntimeHints hints) { - return method -> { + GeneratedMethod generateMethod = generatedClass.getMethods().add("apply", method -> { method.addJavadoc("Apply the autowiring."); method.addModifiers(javax.lang.model.element.Modifier.PUBLIC, javax.lang.model.element.Modifier.STATIC); method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); method.addParameter(this.target, INSTANCE_PARAMETER); method.returns(this.target); - method.addCode(generateMethodCode(hints)); - }; + method.addCode(generateMethodCode(generationContext.getRuntimeHints())); + }); + beanRegistrationCode.addInstancePostProcessor( + MethodReference.ofStatic(generatedClass.getName(), generateMethod.getName())); } private CodeBlock generateMethodCode(RuntimeHints hints) { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java index ed7682a882d..9ad1b274cb6 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -23,9 +23,9 @@ import javax.lang.model.element.Modifier; import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodGenerator; -import org.springframework.aot.generate.MethodNameGenerator; +import org.springframework.aot.generate.MethodName; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.RegisteredBean; @@ -41,8 +41,6 @@ import org.springframework.lang.Nullable; */ class BeanDefinitionMethodGenerator { - private static final String FEATURE_NAME = "BeanDefinitions"; - private final BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; private final RegisteredBean registeredBean; @@ -91,23 +89,22 @@ class BeanDefinitionMethodGenerator { this.constructorOrFactoryMethod); if (!target.getName().startsWith("java.")) { GeneratedClass generatedClass = generationContext.getGeneratedClasses() - .forFeatureComponent(FEATURE_NAME, target) - .getOrGenerate(FEATURE_NAME, type -> { + .getOrAddForFeatureComponent("BeanDefinitions", target, type -> { type.addJavadoc("Bean definitions for {@link $T}", target); type.addModifiers(Modifier.PUBLIC); }); - MethodGenerator methodGenerator = generatedClass.getMethodGenerator() - .withName(getName()); + GeneratedMethods generatedMethods = generatedClass.getMethods() + .withPrefix(getName()); GeneratedMethod generatedMethod = generateBeanDefinitionMethod( - generationContext, generatedClass.getName(), methodGenerator, + generationContext, generatedClass.getName(), generatedMethods, codeFragments, Modifier.PUBLIC); return MethodReference.ofStatic(generatedClass.getName(), generatedMethod.getName()); } - MethodGenerator methodGenerator = beanRegistrationsCode.getMethodGenerator() - .withName(getName()); + GeneratedMethods generatedMethods = beanRegistrationsCode.getMethods() + .withPrefix(getName()); GeneratedMethod generatedMethod = generateBeanDefinitionMethod(generationContext, - beanRegistrationsCode.getClassName(), methodGenerator, codeFragments, + beanRegistrationsCode.getClassName(), generatedMethods, codeFragments, Modifier.PRIVATE); return MethodReference.ofStatic(beanRegistrationsCode.getClassName(), generatedMethod.getName()); @@ -126,22 +123,21 @@ class BeanDefinitionMethodGenerator { private GeneratedMethod generateBeanDefinitionMethod( GenerationContext generationContext, ClassName className, - MethodGenerator methodGenerator, BeanRegistrationCodeFragments codeFragments, + GeneratedMethods generatedMethods, BeanRegistrationCodeFragments codeFragments, Modifier modifier) { BeanRegistrationCodeGenerator codeGenerator = new BeanRegistrationCodeGenerator( - className, methodGenerator, this.registeredBean, + className, generatedMethods, this.registeredBean, this.constructorOrFactoryMethod, codeFragments); - GeneratedMethod method = methodGenerator.generateMethod("get", "bean", "definition"); this.aotContributions.forEach(aotContribution -> aotContribution .applyTo(generationContext, codeGenerator)); - return method.using(builder -> { - builder.addJavadoc("Get the $L definition for '$L'", + return generatedMethods.add("getBeanDefinition", method -> { + method.addJavadoc("Get the $L definition for '$L'", (!this.registeredBean.isInnerBean()) ? "bean" : "inner-bean", getName()); - builder.addModifiers(modifier, Modifier.STATIC); - builder.returns(BeanDefinition.class); - builder.addCode(codeGenerator.generateCode(generationContext)); + method.addModifiers(modifier, Modifier.STATIC); + method.returns(BeanDefinition.class); + method.addCode(codeGenerator.generateCode(generationContext)); }); } @@ -156,10 +152,10 @@ class BeanDefinitionMethodGenerator { while (nonGeneratedParent != null && nonGeneratedParent.isGeneratedBeanName()) { nonGeneratedParent = nonGeneratedParent.getParent(); } - return (nonGeneratedParent != null) - ? MethodNameGenerator.join( - getSimpleBeanName(nonGeneratedParent.getBeanName()), "innerBean") - : "innerBean"; + if (nonGeneratedParent != null) { + return MethodName.of(getSimpleBeanName(nonGeneratedParent.getBeanName()), "innerBean").toString(); + } + return "innerBean"; } private String getSimpleBeanName(String beanName) { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java index d186bafe1d9..cc30e470c56 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java @@ -31,7 +31,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; -import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.hint.ExecutableHint; import org.springframework.aot.hint.ExecutableMode; import org.springframework.aot.hint.RuntimeHints; @@ -81,7 +81,7 @@ class BeanDefinitionPropertiesCodeGenerator { private static final String BEAN_DEFINITION_VARIABLE = BeanRegistrationCodeFragments.BEAN_DEFINITION_VARIABLE; - private static final Consumer INVOKE_HINT = hint -> hint.withMode(ExecutableMode.INVOKE); + private static final Consumer INVOKE_HINT = hint -> hint.withMode(ExecutableMode.INVOKE); private static final BeanInfoFactory beanInfoFactory = new ExtendedBeanInfoFactory(); @@ -95,14 +95,14 @@ class BeanDefinitionPropertiesCodeGenerator { BeanDefinitionPropertiesCodeGenerator(RuntimeHints hints, - Predicate attributeFilter, MethodGenerator methodGenerator, + Predicate attributeFilter, GeneratedMethods generatedMethods, BiFunction customValueCodeGenerator) { this.hints = hints; this.attributeFilter = attributeFilter; this.customValueCodeGenerator = customValueCodeGenerator; this.valueCodeGenerator = new BeanDefinitionPropertyValueCodeGenerator( - methodGenerator); + generatedMethods); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java index 58f67d0b890..c2b9e7a4f50 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGenerator.java @@ -30,8 +30,7 @@ import java.util.TreeMap; import java.util.TreeSet; import org.springframework.aot.generate.GeneratedMethod; -import org.springframework.aot.generate.MethodGenerator; -import org.springframework.aot.generate.MethodNameGenerator; +import org.springframework.aot.generate.GeneratedMethods; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanReference; import org.springframework.beans.factory.config.RuntimeBeanReference; @@ -57,7 +56,7 @@ class BeanDefinitionPropertyValueCodeGenerator { static final CodeBlock NULL_VALUE_CODE_BLOCK = CodeBlock.of("null"); - private final MethodGenerator methodGenerator; + private final GeneratedMethods generatedMethods; private final List delegates = List.of( new PrimitiveDelegate(), @@ -76,8 +75,8 @@ class BeanDefinitionPropertyValueCodeGenerator { ); - BeanDefinitionPropertyValueCodeGenerator(MethodGenerator methodGenerator) { - this.methodGenerator = methodGenerator; + BeanDefinitionPropertyValueCodeGenerator(GeneratedMethods generatedMethods) { + this.generatedMethods = generatedMethods; } @@ -485,24 +484,22 @@ class BeanDefinitionPropertyValueCodeGenerator { private CodeBlock generateLinkedHashMapCode(Map map, ResolvableType keyType, ResolvableType valueType) { - GeneratedMethod method = BeanDefinitionPropertyValueCodeGenerator.this.methodGenerator - .generateMethod(MethodNameGenerator.join("get", "map")) - .using(builder -> { - builder.addAnnotation(AnnotationSpec + GeneratedMethod generatedMethod = generatedMethods.add("getMap", method -> { + method.addAnnotation(AnnotationSpec .builder(SuppressWarnings.class) .addMember("value", "{\"rawtypes\", \"unchecked\"}") .build()); - builder.returns(Map.class); - builder.addStatement("$T map = new $T($L)", Map.class, + method.returns(Map.class); + method.addStatement("$T map = new $T($L)", Map.class, LinkedHashMap.class, map.size()); - map.forEach((key, value) -> builder.addStatement("map.put($L, $L)", + map.forEach((key, value) -> method.addStatement("map.put($L, $L)", BeanDefinitionPropertyValueCodeGenerator.this .generateCode(key, keyType), BeanDefinitionPropertyValueCodeGenerator.this .generateCode(value, valueType))); - builder.addStatement("return map"); + method.addStatement("return map"); }); - return CodeBlock.of("$L()", method.getName()); + return CodeBlock.of("$L()", generatedMethod.getName()); } } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java index 92e250ba7e9..0522eccfee7 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java @@ -16,7 +16,7 @@ package org.springframework.beans.factory.aot; -import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.MethodReference; /** @@ -35,11 +35,10 @@ public interface BeanFactoryInitializationCode { String BEAN_FACTORY_VARIABLE = "beanFactory"; /** - * Return a {@link MethodGenerator} that can be used to add more methods to - * the Initializing code. + * Return the {@link GeneratedMethods} being used by the Initializing code. * @return the method generator */ - MethodGenerator getMethodGenerator(); + GeneratedMethods getMethods(); /** * Add an initializer method call. diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCode.java index 6bcd48646f8..ee7c9db9a56 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCode.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCode.java @@ -16,7 +16,7 @@ package org.springframework.beans.factory.aot; -import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.javapoet.ClassName; @@ -39,11 +39,10 @@ public interface BeanRegistrationCode { ClassName getClassName(); /** - * Return a {@link MethodGenerator} that can be used to add more methods to - * the registrations code. + * Return a {@link GeneratedMethods} being used by the registrations code. * @return the method generator */ - MethodGenerator getMethodGenerator(); + GeneratedMethods getMethods(); /** * Add an instance post processor method call to the registration code. diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java index bd1eb2ba9a7..f3ed9ac42b6 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java @@ -21,8 +21,8 @@ import java.util.ArrayList; import java.util.List; import java.util.function.Predicate; +import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.javapoet.ClassName; @@ -41,7 +41,7 @@ class BeanRegistrationCodeGenerator implements BeanRegistrationCode { private final ClassName className; - private final MethodGenerator methodGenerator; + private final GeneratedMethods generatedMethods; private final List instancePostProcessors = new ArrayList<>(); @@ -52,12 +52,12 @@ class BeanRegistrationCodeGenerator implements BeanRegistrationCode { private final BeanRegistrationCodeFragments codeFragments; - BeanRegistrationCodeGenerator(ClassName className, MethodGenerator methodGenerator, + BeanRegistrationCodeGenerator(ClassName className, GeneratedMethods methodGenerator, RegisteredBean registeredBean, Executable constructorOrFactoryMethod, BeanRegistrationCodeFragments codeFragments) { this.className = className; - this.methodGenerator = methodGenerator; + this.generatedMethods = methodGenerator; this.registeredBean = registeredBean; this.constructorOrFactoryMethod = constructorOrFactoryMethod; this.codeFragments = codeFragments; @@ -69,8 +69,8 @@ class BeanRegistrationCodeGenerator implements BeanRegistrationCode { } @Override - public MethodGenerator getMethodGenerator() { - return this.methodGenerator; + public GeneratedMethods getMethods() { + return this.generatedMethods; } @Override diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java index 7b29174c239..4902ce8a150 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -22,8 +22,8 @@ import javax.lang.model.element.Modifier; import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.javapoet.ClassName; @@ -60,27 +60,24 @@ class BeanRegistrationsAotContribution BeanFactoryInitializationCode beanFactoryInitializationCode) { GeneratedClass generatedClass = generationContext.getGeneratedClasses() - .forFeature("BeanFactoryRegistrations").generate(type -> { + .addForFeature("BeanFactoryRegistrations", type -> { type.addJavadoc("Register bean definitions for the bean factory."); type.addModifiers(Modifier.PUBLIC); }); - BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator( - generatedClass); - GeneratedMethod registerMethod = codeGenerator.getMethodGenerator() - .generateMethod("registerBeanDefinitions") - .using(builder -> generateRegisterMethod(builder, generationContext, - codeGenerator)); - beanFactoryInitializationCode - .addInitializer(MethodReference.of(generatedClass.getName(), registerMethod.getName())); + BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator(generatedClass); + GeneratedMethod generatedMethod = codeGenerator.getMethods().add("registerBeanDefinitions", method -> + generateRegisterMethod(method, generationContext, codeGenerator)); + beanFactoryInitializationCode.addInitializer( + MethodReference.of(generatedClass.getName(), generatedMethod.getName())); } - private void generateRegisterMethod(MethodSpec.Builder builder, + private void generateRegisterMethod(MethodSpec.Builder method, GenerationContext generationContext, BeanRegistrationsCode beanRegistrationsCode) { - builder.addJavadoc("Register the bean definitions."); - builder.addModifiers(Modifier.PUBLIC); - builder.addParameter(DefaultListableBeanFactory.class, + method.addJavadoc("Register the bean definitions."); + method.addModifiers(Modifier.PUBLIC); + method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME); CodeBlock.Builder code = CodeBlock.builder(); this.registrations.forEach((beanName, beanDefinitionMethodGenerator) -> { @@ -91,7 +88,7 @@ class BeanRegistrationsAotContribution BEAN_FACTORY_PARAMETER_NAME, beanName, beanDefinitionMethod.toInvokeCodeBlock()); }); - builder.addCode(code.build()); + method.addCode(code.build()); } @@ -113,8 +110,8 @@ class BeanRegistrationsAotContribution } @Override - public MethodGenerator getMethodGenerator() { - return this.generatedClass.getMethodGenerator(); + public GeneratedMethods getMethods() { + return this.generatedClass.getMethods(); } } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsCode.java index b3ff475fe81..f325d7da0c6 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsCode.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsCode.java @@ -16,7 +16,7 @@ package org.springframework.beans.factory.aot; -import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.GeneratedMethods; import org.springframework.javapoet.ClassName; /** @@ -35,10 +35,9 @@ public interface BeanRegistrationsCode { ClassName getClassName(); /** - * Return a {@link MethodGenerator} that can be used to add more methods to - * the registrations code. + * Return a {@link GeneratedMethods} being used by the registrations code. * @return the method generator */ - MethodGenerator getMethodGenerator(); + GeneratedMethods getMethods(); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index 4f00fa68a4b..cf4e4f045c4 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -105,7 +105,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments return new BeanDefinitionPropertiesCodeGenerator( generationContext.getRuntimeHints(), attributeFilter, - beanRegistrationCode.getMethodGenerator(), + beanRegistrationCode.getMethods(), (name, value) -> generateValueCode(generationContext, name, value)) .generateCode(beanDefinition); } @@ -170,7 +170,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments return new InstanceSupplierCodeGenerator(generationContext, beanRegistrationCode.getClassName(), - beanRegistrationCode.getMethodGenerator(), allowDirectSupplierShortcut) + beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) .generateCode(this.registeredBean, constructorOrFactoryMethod); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index a0703f56683..8ec4ad93e86 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -26,8 +26,8 @@ import java.util.function.Consumer; import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.hint.ExecutableHint; import org.springframework.aot.hint.ExecutableMode; import org.springframework.beans.factory.support.InstanceSupplier; @@ -36,6 +36,7 @@ import org.springframework.core.ResolvableType; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.MethodSpec.Builder; import org.springframework.util.ClassUtils; import org.springframework.util.function.ThrowingSupplier; @@ -60,7 +61,7 @@ class InstanceSupplierCodeGenerator { private static final CodeBlock NO_ARGS = CodeBlock.of(""); - private static final Consumer INTROSPECT = builder -> builder + private static final Consumer INTROSPECT = hint -> hint .withMode(ExecutableMode.INTROSPECT); @@ -68,18 +69,18 @@ class InstanceSupplierCodeGenerator { private final ClassName className; - private final MethodGenerator methodGenerator; + private final GeneratedMethods generatedMethods; private final boolean allowDirectSupplierShortcut; InstanceSupplierCodeGenerator(GenerationContext generationContext, - ClassName className, MethodGenerator methodGenerator, + ClassName className, GeneratedMethods generatedMethods, boolean allowDirectSupplierShortcut) { this.generationContext = generationContext; this.className = className; - this.methodGenerator = methodGenerator; + this.generatedMethods = generatedMethods; this.allowDirectSupplierShortcut = allowDirectSupplierShortcut; } @@ -131,11 +132,11 @@ class InstanceSupplierCodeGenerator { return CodeBlock.of("$T.of($T::new)", ThrowingSupplier.class, declaringClass); } - GeneratedMethod getInstanceMethod = generateGetInstanceMethod() - .using(builder -> buildGetInstanceMethodForConstructor(builder, name, - constructor, declaringClass, dependsOnBean, PRIVATE_STATIC)); + GeneratedMethod generatedMethod = generateGetInstanceMethod(method -> + buildGetInstanceMethodForConstructor(method, name, constructor, declaringClass, + dependsOnBean, PRIVATE_STATIC)); return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, this.className, - getInstanceMethod.getName()); + generatedMethod.getName()); } private CodeBlock generateCodeForInaccessibleConstructor(String name, @@ -143,33 +144,33 @@ class InstanceSupplierCodeGenerator { this.generationContext.getRuntimeHints().reflection() .registerConstructor(constructor); - GeneratedMethod getInstanceMethod = generateGetInstanceMethod().using(builder -> { - builder.addJavadoc("Instantiate the bean instance for '$L'.", name); - builder.addModifiers(PRIVATE_STATIC); - builder.returns(declaringClass); - builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); + GeneratedMethod generatedMethod = generateGetInstanceMethod(method -> { + method.addJavadoc("Instantiate the bean instance for '$L'.", name); + method.addModifiers(PRIVATE_STATIC); + method.returns(declaringClass); + method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); int parameterOffset = (!dependsOnBean) ? 0 : 1; - builder.addStatement( + method.addStatement( generateResolverForConstructor(constructor, parameterOffset)); - builder.addStatement("return resolver.resolveAndInstantiate($L)", + method.addStatement("return resolver.resolveAndInstantiate($L)", REGISTERED_BEAN_PARAMETER_NAME); }); return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, this.className, - getInstanceMethod.getName()); + generatedMethod.getName()); } - private void buildGetInstanceMethodForConstructor(MethodSpec.Builder builder, + private void buildGetInstanceMethodForConstructor(MethodSpec.Builder method, String name, Constructor constructor, Class declaringClass, boolean dependsOnBean, javax.lang.model.element.Modifier... modifiers) { - builder.addJavadoc("Create the bean instance for '$L'.", name); - builder.addModifiers(modifiers); - builder.returns(declaringClass); - builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); + method.addJavadoc("Create the bean instance for '$L'.", name); + method.addModifiers(modifiers); + method.returns(declaringClass); + method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); if (constructor.getParameterCount() == 0) { CodeBlock instantiationCode = generateNewInstanceCodeForConstructor( dependsOnBean, declaringClass, NO_ARGS); - builder.addCode(generateReturnStatement(instantiationCode)); + method.addCode(generateReturnStatement(instantiationCode)); } else { int parameterOffset = (!dependsOnBean) ? 0 : 1; @@ -183,7 +184,7 @@ class InstanceSupplierCodeGenerator { declaringClass, arguments); code.addStatement("return resolver.resolve($L, (args) -> $L)", REGISTERED_BEAN_PARAMETER_NAME, newInstance); - builder.addCode(code.build()); + method.addCode(code.build()); } } @@ -241,11 +242,11 @@ class InstanceSupplierCodeGenerator { return CodeBlock.of("$T.of($T::$L)", ThrowingSupplier.class, declaringClass, factoryMethod.getName()); } - GeneratedMethod getInstanceMethod = generateGetInstanceMethod() - .using(builder -> buildGetInstanceMethodForFactoryMethod(builder, name, - factoryMethod, declaringClass, dependsOnBean, PRIVATE_STATIC)); + GeneratedMethod generatedMethod = generateGetInstanceMethod(method -> + buildGetInstanceMethodForFactoryMethod(method, name, factoryMethod, declaringClass, + dependsOnBean, PRIVATE_STATIC)); return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, this.className, - getInstanceMethod.getName()); + generatedMethod.getName()); } private CodeBlock generateCodeForInaccessibleFactoryMethod(String name, @@ -253,36 +254,36 @@ class InstanceSupplierCodeGenerator { this.generationContext.getRuntimeHints().reflection() .registerMethod(factoryMethod); - GeneratedMethod getInstanceMethod = generateGetInstanceMethod().using(builder -> { - builder.addJavadoc("Instantiate the bean instance for '$L'.", name); - builder.addModifiers(PRIVATE_STATIC); - builder.returns(factoryMethod.getReturnType()); - builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); - builder.addStatement(generateResolverForFactoryMethod(factoryMethod, + GeneratedMethod generatedMethod = generateGetInstanceMethod(method -> { + method.addJavadoc("Instantiate the bean instance for '$L'.", name); + method.addModifiers(PRIVATE_STATIC); + method.returns(factoryMethod.getReturnType()); + method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); + method.addStatement(generateResolverForFactoryMethod(factoryMethod, declaringClass, factoryMethod.getName())); - builder.addStatement("return resolver.resolveAndInstantiate($L)", + method.addStatement("return resolver.resolveAndInstantiate($L)", REGISTERED_BEAN_PARAMETER_NAME); }); return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, this.className, - getInstanceMethod.getName()); + generatedMethod.getName()); } - private void buildGetInstanceMethodForFactoryMethod(MethodSpec.Builder builder, + private void buildGetInstanceMethodForFactoryMethod(MethodSpec.Builder method, String name, Method factoryMethod, Class declaringClass, boolean dependsOnBean, javax.lang.model.element.Modifier... modifiers) { String factoryMethodName = factoryMethod.getName(); - builder.addJavadoc("Get the bean instance for '$L'.", name); - builder.addModifiers(modifiers); - builder.returns(factoryMethod.getReturnType()); + method.addJavadoc("Get the bean instance for '$L'.", name); + method.addModifiers(modifiers); + method.returns(factoryMethod.getReturnType()); if (isThrowingCheckedException(factoryMethod)) { - builder.addException(Exception.class); + method.addException(Exception.class); } - builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); + method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER_NAME); if (factoryMethod.getParameterCount() == 0) { CodeBlock instantiationCode = generateNewInstanceCodeForMethod(dependsOnBean, declaringClass, factoryMethodName, NO_ARGS); - builder.addCode(generateReturnStatement(instantiationCode)); + method.addCode(generateReturnStatement(instantiationCode)); } else { CodeBlock.Builder code = CodeBlock.builder(); @@ -294,7 +295,7 @@ class InstanceSupplierCodeGenerator { declaringClass, factoryMethodName, arguments); code.addStatement("return resolver.resolve($L, (args) -> $L)", REGISTERED_BEAN_PARAMETER_NAME, newInstance); - builder.addCode(code.build()); + method.addCode(code.build()); } } @@ -349,8 +350,8 @@ class InstanceSupplierCodeGenerator { return builder.build(); } - private GeneratedMethod generateGetInstanceMethod() { - return this.methodGenerator.generateMethod("get", "instance"); + private GeneratedMethod generateGetInstanceMethod(Consumer method) { + return this.generatedMethods.add("getInstance", method); } private boolean isThrowingCheckedException(Executable executable) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java index a42f26d46bd..e5155c32d4d 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java @@ -21,13 +21,11 @@ import java.util.function.BiFunction; import javax.lang.model.element.Modifier; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.DefaultGenerationContext; import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.generate.MethodReference; -import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; import org.springframework.aot.test.generator.compile.CompileWithTargetClassAccess; import org.springframework.aot.test.generator.compile.Compiled; @@ -41,10 +39,8 @@ import org.springframework.core.env.Environment; import org.springframework.core.env.StandardEnvironment; import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; -import org.springframework.javapoet.TypeSpec; import static org.assertj.core.api.Assertions.assertThat; @@ -57,25 +53,23 @@ import static org.assertj.core.api.Assertions.assertThat; */ class AutowiredAnnotationBeanRegistrationAotContributionTests { - private InMemoryGeneratedFiles generatedFiles; + private final InMemoryGeneratedFiles generatedFiles; - private DefaultGenerationContext generationContext; + private final DefaultGenerationContext generationContext; - private RuntimeHints runtimeHints; + private final MockBeanRegistrationCode beanRegistrationCode; - private MockBeanRegistrationCode beanRegistrationCode; + private final DefaultListableBeanFactory beanFactory; - private DefaultListableBeanFactory beanFactory; - @BeforeEach - void setup() { + AutowiredAnnotationBeanRegistrationAotContributionTests() { this.generatedFiles = new InMemoryGeneratedFiles(); this.generationContext = new TestGenerationContext(this.generatedFiles); - this.runtimeHints = this.generationContext.getRuntimeHints(); - this.beanRegistrationCode = new MockBeanRegistrationCode(); + this.beanRegistrationCode = new MockBeanRegistrationCode(this.generationContext); this.beanFactory = new DefaultListableBeanFactory(); } + @Test void contributeWhenPrivateFieldInjectionInjectsUsingReflection() { Environment environment = new StandardEnvironment(); @@ -84,8 +78,8 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { PrivateFieldInjectionSample.class); assertThat(RuntimeHintsPredicates.reflection() .onField(PrivateFieldInjectionSample.class, "environment").allowWrite()) - .accepts(this.runtimeHints); - testCompiledResult(registeredBean, (postProcessor, compiled) -> { + .accepts(this.generationContext.getRuntimeHints()); + compile(registeredBean, (postProcessor, compiled) -> { PrivateFieldInjectionSample instance = new PrivateFieldInjectionSample(); postProcessor.apply(registeredBean, instance); assertThat(instance).extracting("environment").isSameAs(environment); @@ -103,8 +97,8 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { PackagePrivateFieldInjectionSample.class); assertThat(RuntimeHintsPredicates.reflection() .onField(PackagePrivateFieldInjectionSample.class, "environment").allowWrite()) - .accepts(this.runtimeHints); - testCompiledResult(registeredBean, (postProcessor, compiled) -> { + .accepts(this.generationContext.getRuntimeHints()); + compile(registeredBean, (postProcessor, compiled) -> { PackagePrivateFieldInjectionSample instance = new PackagePrivateFieldInjectionSample(); postProcessor.apply(registeredBean, instance); assertThat(instance).extracting("environment").isSameAs(environment); @@ -121,8 +115,8 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { PrivateMethodInjectionSample.class); assertThat(RuntimeHintsPredicates.reflection() .onMethod(PrivateMethodInjectionSample.class, "setTestBean").invoke()) - .accepts(this.runtimeHints); - testCompiledResult(registeredBean, (postProcessor, compiled) -> { + .accepts(this.generationContext.getRuntimeHints()); + compile(registeredBean, (postProcessor, compiled) -> { PrivateMethodInjectionSample instance = new PrivateMethodInjectionSample(); postProcessor.apply(registeredBean, instance); assertThat(instance).extracting("environment").isSameAs(environment); @@ -140,8 +134,8 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { PackagePrivateMethodInjectionSample.class); assertThat(RuntimeHintsPredicates.reflection() .onMethod(PackagePrivateMethodInjectionSample.class, "setTestBean").introspect()) - .accepts(this.runtimeHints); - testCompiledResult(registeredBean, (postProcessor, compiled) -> { + .accepts(this.generationContext.getRuntimeHints()); + compile(registeredBean, (postProcessor, compiled) -> { PackagePrivateMethodInjectionSample instance = new PackagePrivateMethodInjectionSample(); postProcessor.apply(registeredBean, instance); assertThat(instance).extracting("environment").isSameAs(environment); @@ -167,29 +161,24 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { } @SuppressWarnings("unchecked") - private void testCompiledResult(RegisteredBean registeredBean, + private void compile(RegisteredBean registeredBean, BiConsumer, Compiled> result) { - this.generationContext.writeGeneratedContent(); - JavaFile javaFile = createJavaFile(registeredBean.getBeanClass()); - TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo, - compiled -> result.accept(compiled.getInstance(BiFunction.class), - compiled)); - } + Class target = registeredBean.getBeanClass(); + MethodReference methodReference = this.beanRegistrationCode.getInstancePostProcessors().get(0); + this.beanRegistrationCode.getTypeBuilder().set(type -> { + type.addModifiers(Modifier.PUBLIC); + type.addSuperinterface(ParameterizedTypeName.get(BiFunction.class, RegisteredBean.class, target, target)); + type.addMethod(MethodSpec.methodBuilder("apply") + .addModifiers(Modifier.PUBLIC) + .addParameter(RegisteredBean.class, "registeredBean") + .addParameter(target, "instance").returns(target) + .addStatement("return $L", methodReference.toInvokeCodeBlock(CodeBlock.of("registeredBean"), CodeBlock.of("instance"))) + .build()); - private JavaFile createJavaFile(Class target) { - MethodReference methodReference = this.beanRegistrationCode.getInstancePostProcessors() - .get(0); - TypeSpec.Builder builder = TypeSpec.classBuilder("TestPostProcessor"); - builder.addModifiers(Modifier.PUBLIC); - builder.addSuperinterface(ParameterizedTypeName.get(BiFunction.class, - RegisteredBean.class, target, target)); - builder.addMethod(MethodSpec.methodBuilder("apply").addModifiers(Modifier.PUBLIC) - .addParameter(RegisteredBean.class, "registeredBean") - .addParameter(target, "instance").returns(target) - .addStatement("return $L", methodReference.toInvokeCodeBlock( - CodeBlock.of("registeredBean"), CodeBlock.of("instance"))) - .build()); - return JavaFile.builder("__", builder.build()).build(); + }); + this.generationContext.writeGeneratedContent(); + TestCompiler.forSystem().withFiles(this.generatedFiles).compile(compiled -> + result.accept(compiled.getInstance(BiFunction.class), compiled)); } } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index a8183ce4cb0..49718c0e575 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -25,7 +25,6 @@ import java.util.function.Supplier; import javax.lang.model.element.Modifier; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.DefaultGenerationContext; @@ -51,12 +50,9 @@ import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrat import org.springframework.core.ResolvableType; import org.springframework.core.mock.MockSpringFactoriesLoader; import org.springframework.core.testfixture.aot.generate.TestGenerationContext; -import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; -import org.springframework.javapoet.TypeSpec; import static org.assertj.core.api.Assertions.assertThat; @@ -68,27 +64,27 @@ import static org.assertj.core.api.Assertions.assertThat; */ class BeanDefinitionMethodGeneratorTests { - private InMemoryGeneratedFiles generatedFiles; + private final InMemoryGeneratedFiles generatedFiles; - private DefaultGenerationContext generationContext; + private final DefaultGenerationContext generationContext; - private DefaultListableBeanFactory beanFactory; + private final DefaultListableBeanFactory beanFactory; - private MockBeanRegistrationsCode beanRegistrationsCode; + private final MockBeanRegistrationsCode beanRegistrationsCode; - private BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; + private final BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; - @BeforeEach - void setup() { + + BeanDefinitionMethodGeneratorTests() { this.generatedFiles = new InMemoryGeneratedFiles(); this.generationContext = new TestGenerationContext(this.generatedFiles); this.beanFactory = new DefaultListableBeanFactory(); this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( new AotFactoriesLoader(this.beanFactory, new MockSpringFactoriesLoader())); - this.beanRegistrationsCode = new MockBeanRegistrationsCode( - ClassName.get("__", "Registration")); + this.beanRegistrationsCode = new MockBeanRegistrationsCode(this.generationContext); } + @Test void generateBeanDefinitionMethodGeneratesMethod() { RegisteredBean registeredBean = registerBean( @@ -98,7 +94,7 @@ class BeanDefinitionMethodGeneratorTests { Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); - testCompiledResult(method, (actual, compiled) -> { + compile(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); assertThat(sourceFile).contains("beanType = TestBean.class"); @@ -116,7 +112,7 @@ class BeanDefinitionMethodGeneratorTests { Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); - testCompiledResult(method, (actual, compiled) -> { + compile(method, (actual, compiled) -> { assertThat(actual.getResolvableType().resolve()).isEqualTo(GenericBean.class); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); @@ -133,15 +129,13 @@ class BeanDefinitionMethodGeneratorTests { new RootBeanDefinition(TestBean.class)); BeanRegistrationAotContribution aotContribution = (generationContext, beanRegistrationCode) -> { - GeneratedMethod method = beanRegistrationCode.getMethodGenerator() - .generateMethod("postProcess") - .using(builder -> builder.addModifiers(Modifier.STATIC) + GeneratedMethod generatedMethod = beanRegistrationCode.getMethods().add("postProcess", method -> + method.addModifiers(Modifier.STATIC) .addParameter(RegisteredBean.class, "registeredBean") .addParameter(TestBean.class, "testBean") - .returns(TestBean.class).addCode("return new $T($S);", - TestBean.class, "postprocessed")); + .returns(TestBean.class).addCode("return new $T($S);", TestBean.class, "postprocessed")); beanRegistrationCode.addInstancePostProcessor(MethodReference.ofStatic( - beanRegistrationCode.getClassName(), method.getName())); + beanRegistrationCode.getClassName(), generatedMethod.getName())); }; List aotContributions = Collections .singletonList(aotContribution); @@ -149,7 +143,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); - testCompiledResult(method, (actual, compiled) -> { + compile(method, (actual, compiled) -> { assertThat(actual.getBeanClass()).isEqualTo(TestBean.class); InstanceSupplier supplier = (InstanceSupplier) actual .getInstanceSupplier(); @@ -175,7 +169,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); - testCompiledResult(method, (actual, compiled) -> { + compile(method, (actual, compiled) -> { assertThat(actual.getBeanClass()).isEqualTo(TestBean.class); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("I am custom"); @@ -215,7 +209,7 @@ class BeanDefinitionMethodGeneratorTests { aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); - testCompiledResult(method, (actual, compiled) -> { + compile(method, (actual, compiled) -> { assertThat(actual.getAttribute("a")).isEqualTo("A"); assertThat(actual.getAttribute("b")).isNull(); }); @@ -248,7 +242,7 @@ class BeanDefinitionMethodGeneratorTests { Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); - testCompiledResult(method, (actual, compiled) -> { + compile(method, (actual, compiled) -> { assertThat(compiled.getSourceFile(".*BeanDefinitions")) .contains("Get the inner-bean definition for 'testInnerBean'"); assertThat(actual).isInstanceOf(RootBeanDefinition.class); @@ -269,7 +263,7 @@ class BeanDefinitionMethodGeneratorTests { Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); - testCompiledResult(method, (actual, compiled) -> { + compile(method, (actual, compiled) -> { RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual .getPropertyValues().get("name"); assertThat(actualInnerBeanDefinition.isPrimary()).isTrue(); @@ -303,7 +297,7 @@ class BeanDefinitionMethodGeneratorTests { Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); - testCompiledResult(method, (actual, compiled) -> { + compile(method, (actual, compiled) -> { RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual .getConstructorArgumentValues() .getIndexedArgumentValue(0, RootBeanDefinition.class).getValue(); @@ -328,15 +322,14 @@ class BeanDefinitionMethodGeneratorTests { RegisteredBean registeredBean = registerBean( new RootBeanDefinition(TestBean.class)); List aotContributions = new ArrayList<>(); - aotContributions - .add((generationContext, beanRegistrationCode) -> beanRegistrationCode - .getMethodGenerator().generateMethod("aotContributedMethod") - .using(builder -> builder.addComment("Example Contribution"))); + aotContributions.add((generationContext, beanRegistrationCode) -> + beanRegistrationCode.getMethods().add("aotContributedMethod", method -> + method.addComment("Example Contribution"))); BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); - testCompiledResult(method, (actual, compiled) -> { + compile(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("AotContributedMethod()"); assertThat(sourceFile).contains("Example Contribution"); @@ -353,7 +346,7 @@ class BeanDefinitionMethodGeneratorTests { Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( this.generationContext, this.beanRegistrationsCode); - testCompiledResult(method, (actual, compiled) -> { + compile(method, (actual, compiled) -> { DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); freshBeanFactory.registerBeanDefinition("test", actual); Object bean = freshBeanFactory.getBean("test"); @@ -369,27 +362,19 @@ class BeanDefinitionMethodGeneratorTests { return RegisteredBean.of(this.beanFactory, beanName); } - private void testCompiledResult(MethodReference method, + private void compile(MethodReference method, BiConsumer result) { + this.beanRegistrationsCode.getTypeBuilder().set(type -> { + type.addModifiers(Modifier.PUBLIC); + type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, BeanDefinition.class)); + type.addMethod(MethodSpec.methodBuilder("get") + .addModifiers(Modifier.PUBLIC) + .returns(BeanDefinition.class) + .addCode("return $L;", method.toInvokeCodeBlock()).build()); + }); this.generationContext.writeGeneratedContent(); - JavaFile javaFile = generateJavaFile(method); - TestCompiler.forSystem().withFiles(this.generatedFiles).compile( - javaFile::writeTo, compiled -> result.accept( - (RootBeanDefinition) compiled.getInstance(Supplier.class).get(), - compiled)); - } - - private JavaFile generateJavaFile(MethodReference method) { - TypeSpec.Builder builder = TypeSpec.classBuilder("Registration"); - builder.addModifiers(Modifier.PUBLIC); - builder.addSuperinterface( - ParameterizedTypeName.get(Supplier.class, BeanDefinition.class)); - builder.addMethod(MethodSpec.methodBuilder("get").addModifiers(Modifier.PUBLIC) - .returns(BeanDefinition.class) - .addCode("return $L;", method.toInvokeCodeBlock()).build()); - this.beanRegistrationsCode.getMethodGenerator() - .doWithMethodSpecs(builder::addMethod); - return JavaFile.builder("__", builder.build()).build(); + TestCompiler.forSystem().withFiles(this.generatedFiles).compile(compiled -> + result.accept((RootBeanDefinition) compiled.getInstance(Supplier.class).get(), compiled)); } } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java index 5b6f9e69d6b..2b1e78d3040 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java @@ -27,8 +27,9 @@ import javax.lang.model.element.Modifier; import org.junit.jupiter.api.Test; -import org.springframework.aot.generate.GeneratedMethods; -import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.generate.DefaultGenerationContext; +import org.springframework.aot.generate.GeneratedClass; +import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; import org.springframework.aot.test.generator.compile.Compiled; import org.springframework.aot.test.generator.compile.TestCompiler; @@ -43,11 +44,11 @@ import org.springframework.beans.factory.support.ManagedList; import org.springframework.beans.factory.support.ManagedMap; import org.springframework.beans.factory.support.ManagedSet; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.factory.aot.DeferredTypeBuilder; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; -import org.springframework.javapoet.TypeSpec; import static org.assertj.core.api.Assertions.assertThat; @@ -61,18 +62,14 @@ class BeanDefinitionPropertiesCodeGeneratorTests { private final RootBeanDefinition beanDefinition = new RootBeanDefinition(); - private final GeneratedMethods generatedMethods = new GeneratedMethods(); - - private final RuntimeHints hints = new RuntimeHints(); - - private BeanDefinitionPropertiesCodeGenerator generator = new BeanDefinitionPropertiesCodeGenerator( - this.hints, attribute -> true, this.generatedMethods, (name, value) -> null); + private final InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); + private final DefaultGenerationContext generationContext = new TestGenerationContext(this.generatedFiles); @Test void setPrimaryWhenFalse() { this.beanDefinition.setPrimary(false); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(compiled.getSourceFile()).doesNotContain("setPrimary"); assertThat(actual.isPrimary()).isFalse(); }); @@ -81,13 +78,13 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setPrimaryWhenTrue() { this.beanDefinition.setPrimary(true); - testCompiledResult((actual, compiled) -> assertThat(actual.isPrimary()).isTrue()); + compile((actual, compiled) -> assertThat(actual.isPrimary()).isTrue()); } @Test void setScopeWhenEmptyString() { this.beanDefinition.setScope(""); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(compiled.getSourceFile()).doesNotContain("setScope"); assertThat(actual.getScope()).isEmpty(); }); @@ -96,7 +93,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setScopeWhenSingleton() { this.beanDefinition.setScope("singleton"); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(compiled.getSourceFile()).doesNotContain("setScope"); assertThat(actual.getScope()).isEmpty(); }); @@ -105,14 +102,14 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setScopeWhenOther() { this.beanDefinition.setScope("prototype"); - testCompiledResult((actual, compiled) -> assertThat(actual.getScope()) + compile((actual, compiled) -> assertThat(actual.getScope()) .isEqualTo("prototype")); } @Test void setDependsOnWhenEmpty() { this.beanDefinition.setDependsOn(); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(compiled.getSourceFile()).doesNotContain("setDependsOn"); assertThat(actual.getDependsOn()).isNull(); }); @@ -121,13 +118,13 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setDependsOnWhenNotEmpty() { this.beanDefinition.setDependsOn("a", "b", "c"); - testCompiledResult((actual, compiled) -> assertThat(actual.getDependsOn()) + compile((actual, compiled) -> assertThat(actual.getDependsOn()) .containsExactly("a", "b", "c")); } @Test void setLazyInitWhenNoSet() { - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(compiled.getSourceFile()).doesNotContain("setLazyInit"); assertThat(actual.isLazyInit()).isFalse(); assertThat(actual.getLazyInit()).isNull(); @@ -137,7 +134,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setLazyInitWhenFalse() { this.beanDefinition.setLazyInit(false); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(actual.isLazyInit()).isFalse(); assertThat(actual.getLazyInit()).isFalse(); }); @@ -146,7 +143,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setLazyInitWhenTrue() { this.beanDefinition.setLazyInit(true); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(actual.isLazyInit()).isTrue(); assertThat(actual.getLazyInit()).isTrue(); }); @@ -155,14 +152,14 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setAutowireCandidateWhenFalse() { this.beanDefinition.setAutowireCandidate(false); - testCompiledResult( + compile( (actual, compiled) -> assertThat(actual.isAutowireCandidate()).isFalse()); } @Test void setAutowireCandidateWhenTrue() { this.beanDefinition.setAutowireCandidate(true); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(compiled.getSourceFile()).doesNotContain("setAutowireCandidate"); assertThat(actual.isAutowireCandidate()).isTrue(); }); @@ -171,7 +168,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setSyntheticWhenFalse() { this.beanDefinition.setSynthetic(false); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(compiled.getSourceFile()).doesNotContain("setSynthetic"); assertThat(actual.isSynthetic()).isFalse(); }); @@ -180,14 +177,14 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setSyntheticWhenTrue() { this.beanDefinition.setSynthetic(true); - testCompiledResult( + compile( (actual, compiled) -> assertThat(actual.isSynthetic()).isTrue()); } @Test void setRoleWhenApplication() { this.beanDefinition.setRole(BeanDefinition.ROLE_APPLICATION); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(compiled.getSourceFile()).doesNotContain("setRole"); assertThat(actual.getRole()).isEqualTo(BeanDefinition.ROLE_APPLICATION); }); @@ -196,7 +193,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setRoleWhenInfrastructure() { this.beanDefinition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(compiled.getSourceFile()) .contains("setRole(BeanDefinition.ROLE_INFRASTRUCTURE);"); assertThat(actual.getRole()).isEqualTo(BeanDefinition.ROLE_INFRASTRUCTURE); @@ -206,7 +203,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setRoleWhenSupport() { this.beanDefinition.setRole(BeanDefinition.ROLE_SUPPORT); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(compiled.getSourceFile()) .contains("setRole(BeanDefinition.ROLE_SUPPORT);"); assertThat(actual.getRole()).isEqualTo(BeanDefinition.ROLE_SUPPORT); @@ -216,7 +213,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void setRoleWhenOther() { this.beanDefinition.setRole(999); - testCompiledResult( + compile( (actual, compiled) -> assertThat(actual.getRole()).isEqualTo(999)); } @@ -224,7 +221,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { void setInitMethodWhenSingleInitMethod() { this.beanDefinition.setTargetType(InitDestroyBean.class); this.beanDefinition.setInitMethodName("i1"); - testCompiledResult((actual, compiled) -> assertThat(actual.getInitMethodNames()) + compile((actual, compiled) -> assertThat(actual.getInitMethodNames()) .containsExactly("i1")); String[] methodNames = { "i1" }; assertHasMethodInvokeHints(InitDestroyBean.class, methodNames); @@ -234,14 +231,14 @@ class BeanDefinitionPropertiesCodeGeneratorTests { void setInitMethodWhenSingleInferredInitMethod() { this.beanDefinition.setTargetType(InitDestroyBean.class); this.beanDefinition.setInitMethodName(AbstractBeanDefinition.INFER_METHOD); - testCompiledResult((actual, compiled) -> assertThat(actual.getInitMethodNames()).isNull()); + compile((actual, compiled) -> assertThat(actual.getInitMethodNames()).isNull()); } @Test void setInitMethodWhenMultipleInitMethods() { this.beanDefinition.setTargetType(InitDestroyBean.class); this.beanDefinition.setInitMethodNames("i1", "i2"); - testCompiledResult((actual, compiled) -> assertThat(actual.getInitMethodNames()) + compile((actual, compiled) -> assertThat(actual.getInitMethodNames()) .containsExactly("i1", "i2")); String[] methodNames = { "i1", "i2" }; assertHasMethodInvokeHints(InitDestroyBean.class, methodNames); @@ -251,7 +248,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { void setDestroyMethodWhenDestroyInitMethod() { this.beanDefinition.setTargetType(InitDestroyBean.class); this.beanDefinition.setDestroyMethodName("d1"); - testCompiledResult( + compile( (actual, compiled) -> assertThat(actual.getDestroyMethodNames()) .containsExactly("d1")); String[] methodNames = { "d1" }; @@ -262,14 +259,14 @@ class BeanDefinitionPropertiesCodeGeneratorTests { void setDestroyMethodWhenSingleInferredInitMethod() { this.beanDefinition.setTargetType(InitDestroyBean.class); this.beanDefinition.setDestroyMethodName(AbstractBeanDefinition.INFER_METHOD); - testCompiledResult((actual, compiled) -> assertThat(actual.getDestroyMethodNames()).isNull()); + compile((actual, compiled) -> assertThat(actual.getDestroyMethodNames()).isNull()); } @Test void setDestroyMethodWhenMultipleDestroyMethods() { this.beanDefinition.setTargetType(InitDestroyBean.class); this.beanDefinition.setDestroyMethodNames("d1", "d2"); - testCompiledResult( + compile( (actual, compiled) -> assertThat(actual.getDestroyMethodNames()) .containsExactly("d1", "d2")); String[] methodNames = { "d1", "d2" }; @@ -277,8 +274,9 @@ class BeanDefinitionPropertiesCodeGeneratorTests { } private void assertHasMethodInvokeHints(Class beanType, String... methodNames) { - assertThat(methodNames).allMatch(methodName -> - RuntimeHintsPredicates.reflection().onMethod(beanType, methodName).invoke().test(this.hints)); + assertThat(methodNames).allMatch(methodName -> RuntimeHintsPredicates.reflection() + .onMethod(beanType, methodName).invoke() + .test(this.generationContext.getRuntimeHints())); } @Test @@ -289,7 +287,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { "test"); this.beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(2, 123); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { Map values = actual.getConstructorArgumentValues() .getIndexedArgumentValues(); assertThat(values.get(0).getValue()).isEqualTo(String.class); @@ -303,7 +301,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { this.beanDefinition.setTargetType(PropertyValuesBean.class); this.beanDefinition.getPropertyValues().add("test", String.class); this.beanDefinition.getPropertyValues().add("spring", "framework"); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(actual.getPropertyValues().get("test")).isEqualTo(String.class); assertThat(actual.getPropertyValues().get("spring")).isEqualTo("framework"); }); @@ -315,7 +313,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { void propertyValuesWhenContainsBeanReference() { this.beanDefinition.getPropertyValues().add("myService", new RuntimeBeanNameReference("test")); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(actual.getPropertyValues().contains("myService")).isTrue(); assertThat(actual.getPropertyValues().get("myService")) .isInstanceOfSatisfying(RuntimeBeanReference.class, @@ -329,7 +327,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { ManagedList managedList = new ManagedList<>(); managedList.add(new RuntimeBeanNameReference("test")); this.beanDefinition.getPropertyValues().add("value", managedList); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { Object value = actual.getPropertyValues().get("value"); assertThat(value).isInstanceOf(ManagedList.class); assertThat(((List) value).get(0)).isInstanceOf(BeanReference.class); @@ -341,7 +339,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { ManagedSet managedSet = new ManagedSet<>(); managedSet.add(new RuntimeBeanNameReference("test")); this.beanDefinition.getPropertyValues().add("value", managedSet); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { Object value = actual.getPropertyValues().get("value"); assertThat(value).isInstanceOf(ManagedSet.class); assertThat(((Set) value).iterator().next()) @@ -354,7 +352,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { ManagedMap managedMap = new ManagedMap<>(); managedMap.put("test", new RuntimeBeanNameReference("test")); this.beanDefinition.getPropertyValues().add("value", managedMap); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { Object value = actual.getPropertyValues().get("value"); assertThat(value).isInstanceOf(ManagedMap.class); assertThat(((Map) value).get("test")).isInstanceOf(BeanReference.class); @@ -366,9 +364,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { this.beanDefinition.setAttribute("a", "A"); this.beanDefinition.setAttribute("b", "B"); Predicate attributeFilter = attribute -> false; - this.generator = new BeanDefinitionPropertiesCodeGenerator(this.hints, - attributeFilter, this.generatedMethods, (name, value) -> null); - testCompiledResult((actual, compiled) -> { + compile(attributeFilter, (actual, compiled) -> { assertThat(compiled.getSourceFile()).doesNotContain("setAttribute"); assertThat(actual.getAttribute("a")).isNull(); assertThat(actual.getAttribute("b")).isNull(); @@ -380,9 +376,7 @@ class BeanDefinitionPropertiesCodeGeneratorTests { this.beanDefinition.setAttribute("a", "A"); this.beanDefinition.setAttribute("b", "B"); Predicate attributeFilter = "a"::equals; - this.generator = new BeanDefinitionPropertiesCodeGenerator(this.hints, - attributeFilter, this.generatedMethods, (name, value) -> null); - testCompiledResult(this.beanDefinition, (actual, compiled) -> { + compile(attributeFilter, (actual, compiled) -> { assertThat(actual.getAttribute("a")).isEqualTo("A"); assertThat(actual.getAttribute("b")).isNull(); }); @@ -393,47 +387,43 @@ class BeanDefinitionPropertiesCodeGeneratorTests { this.beanDefinition.setPrimary(true); this.beanDefinition.setScope("test"); this.beanDefinition.setRole(BeanDefinition.ROLE_SUPPORT); - testCompiledResult((actual, compiled) -> { + compile((actual, compiled) -> { assertThat(actual.isPrimary()).isTrue(); assertThat(actual.getScope()).isEqualTo("test"); assertThat(actual.getRole()).isEqualTo(BeanDefinition.ROLE_SUPPORT); }); } - private void testCompiledResult(BiConsumer result) { - testCompiledResult(this.beanDefinition, result); + private void compile(BiConsumer result) { + compile(attribute -> true, result); } - private void testCompiledResult(RootBeanDefinition beanDefinition, + private void compile( + Predicate attributeFilter, BiConsumer result) { - testCompiledResult(() -> this.generator.generateCode(beanDefinition), result); - } - - private void testCompiledResult(Supplier codeBlock, - BiConsumer result) { - JavaFile javaFile = createJavaFile(codeBlock); - TestCompiler.forSystem().compile(javaFile::writeTo, compiled -> { - RootBeanDefinition beanDefinition = (RootBeanDefinition) compiled - .getInstance(Supplier.class).get(); - result.accept(beanDefinition, compiled); + DeferredTypeBuilder typeBuilder = new DeferredTypeBuilder(); + GeneratedClass generatedClass = this.generationContext.getGeneratedClasses().addForFeature("TestCode", typeBuilder); + BeanDefinitionPropertiesCodeGenerator codeGenerator = new BeanDefinitionPropertiesCodeGenerator( + this.generationContext.getRuntimeHints(), attributeFilter, + generatedClass.getMethods(), (name, value) -> null); + CodeBlock generatedCode = codeGenerator.generateCode(this.beanDefinition); + typeBuilder.set(type -> { + type.addModifiers(Modifier.PUBLIC); + type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, RootBeanDefinition.class)); + type.addMethod(MethodSpec.methodBuilder("get") + .addModifiers(Modifier.PUBLIC) + .returns(RootBeanDefinition.class) + .addStatement("$T beanDefinition = new $T()", RootBeanDefinition.class, RootBeanDefinition.class) + .addStatement("$T beanFactory = new $T()", DefaultListableBeanFactory.class, DefaultListableBeanFactory.class) + .addCode(generatedCode) + .addStatement("return beanDefinition").build()); + }); + this.generationContext.writeGeneratedContent(); + TestCompiler.forSystem().withFiles(this.generatedFiles).compile(compiled -> { + RootBeanDefinition suppliedBeanDefinition = (RootBeanDefinition) compiled + .getInstance(Supplier.class).get(); + result.accept(suppliedBeanDefinition, compiled); }); - } - - private JavaFile createJavaFile(Supplier codeBlock) { - TypeSpec.Builder builder = TypeSpec.classBuilder("BeanSupplier"); - builder.addModifiers(Modifier.PUBLIC); - builder.addSuperinterface( - ParameterizedTypeName.get(Supplier.class, RootBeanDefinition.class)); - builder.addMethod(MethodSpec.methodBuilder("get").addModifiers(Modifier.PUBLIC) - .returns(RootBeanDefinition.class) - .addStatement("$T beanDefinition = new $T()", RootBeanDefinition.class, - RootBeanDefinition.class) - .addStatement("$T beanFactory = new $T()", - DefaultListableBeanFactory.class, - DefaultListableBeanFactory.class) - .addCode(codeBlock.get()).addStatement("return beanDefinition").build()); - this.generatedMethods.doWithMethodSpecs(builder::addMethod); - return JavaFile.builder("com.example", builder.build()).build(); } static class InitDestroyBean { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java index bbdeb883e29..cacecfbc176 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertyValueCodeGeneratorTests.java @@ -33,21 +33,22 @@ import javax.lang.model.element.Modifier; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; -import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.DefaultGenerationContext; +import org.springframework.aot.generate.GeneratedClass; +import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.test.generator.compile.Compiled; import org.springframework.aot.test.generator.compile.TestCompiler; -import org.springframework.aot.test.generator.file.SourceFile; import org.springframework.beans.factory.config.BeanReference; import org.springframework.beans.factory.config.RuntimeBeanNameReference; import org.springframework.beans.factory.support.ManagedList; import org.springframework.beans.factory.support.ManagedMap; import org.springframework.beans.factory.support.ManagedSet; +import org.springframework.beans.testfixture.beans.factory.aot.DeferredTypeBuilder; import org.springframework.core.ResolvableType; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; -import org.springframework.javapoet.TypeSpec; import static org.assertj.core.api.Assertions.assertThat; @@ -61,28 +62,23 @@ import static org.assertj.core.api.Assertions.assertThat; */ class BeanDefinitionPropertyValueCodeGeneratorTests { - private GeneratedMethods generatedMethods = new GeneratedMethods(); - - private BeanDefinitionPropertyValueCodeGenerator instance = new BeanDefinitionPropertyValueCodeGenerator( - generatedMethods); - private void compile(Object value, BiConsumer result) { - CodeBlock code = instance.generateCode(value); - JavaFile javaFile = createJavaFile(code); - TestCompiler.forSystem().compile(SourceFile.of(javaFile::writeTo), - compiled -> result.accept(compiled.getInstance(Supplier.class).get(), - compiled)); - } - - private JavaFile createJavaFile(CodeBlock code) { - TypeSpec.Builder builder = TypeSpec.classBuilder("InstanceSupplier"); - builder.addModifiers(Modifier.PUBLIC); - builder.addSuperinterface( - ParameterizedTypeName.get(Supplier.class, Object.class)); - builder.addMethod(MethodSpec.methodBuilder("get").addModifiers(Modifier.PUBLIC) - .returns(Object.class).addStatement("return $L", code).build()); - generatedMethods.doWithMethodSpecs(builder::addMethod); - return JavaFile.builder("com.example", builder.build()).build(); + InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); + DefaultGenerationContext generationContext = new TestGenerationContext(generatedFiles); + DeferredTypeBuilder typeBuilder = new DeferredTypeBuilder(); + GeneratedClass generatedClass = generationContext.getGeneratedClasses().addForFeature("TestCode", typeBuilder); + CodeBlock generatedCode = new BeanDefinitionPropertyValueCodeGenerator( + generatedClass.getMethods()).generateCode(value); + typeBuilder.set(type -> { + type.addModifiers(Modifier.PUBLIC); + type.addSuperinterface( + ParameterizedTypeName.get(Supplier.class, Object.class)); + type.addMethod(MethodSpec.methodBuilder("get").addModifiers(Modifier.PUBLIC) + .returns(Object.class).addStatement("return $L", generatedCode).build()); + }); + generationContext.writeGeneratedContent(); + TestCompiler.forSystem().withFiles(generatedFiles).compile(compiled -> + result.accept(compiled.getInstance(Supplier.class).get(), compiled)); } @Nested diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java index 580855d9462..1b6afb8cc7c 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java @@ -26,7 +26,6 @@ import java.util.function.Consumer; import javax.lang.model.element.Modifier; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.ClassNameGenerator; @@ -46,10 +45,8 @@ import org.springframework.core.mock.MockSpringFactoriesLoader; import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.core.testfixture.aot.generate.TestTarget; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; -import org.springframework.javapoet.TypeSpec; import static org.assertj.core.api.Assertions.assertThat; @@ -60,28 +57,30 @@ import static org.assertj.core.api.Assertions.assertThat; */ class BeanRegistrationsAotContributionTests { - private InMemoryGeneratedFiles generatedFiles; - - private DefaultGenerationContext generationContext; + private final MockSpringFactoriesLoader springFactoriesLoader; private DefaultListableBeanFactory beanFactory; - private MockSpringFactoriesLoader springFactoriesLoader; + private final InMemoryGeneratedFiles generatedFiles; - private BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; + private DefaultGenerationContext generationContext; - private MockBeanFactoryInitializationCode beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); + private final BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; - @BeforeEach - void setup() { + private MockBeanFactoryInitializationCode beanFactoryInitializationCode; + + + BeanRegistrationsAotContributionTests() { + this.springFactoriesLoader = new MockSpringFactoriesLoader(); + this.beanFactory = new DefaultListableBeanFactory(); this.generatedFiles = new InMemoryGeneratedFiles(); this.generationContext = new TestGenerationContext(this.generatedFiles); - this.beanFactory = new DefaultListableBeanFactory(); - this.springFactoriesLoader = new MockSpringFactoriesLoader(); this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( new AotFactoriesLoader(this.beanFactory, this.springFactoriesLoader)); + this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(this.generationContext); } + @Test void applyToAppliesContribution() { Map registrations = new LinkedHashMap<>(); @@ -94,7 +93,7 @@ class BeanRegistrationsAotContributionTests { BeanRegistrationsAotContribution contribution = new BeanRegistrationsAotContribution( registrations); contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); - testCompiledResult((consumer, compiled) -> { + compile((consumer, compiled) -> { DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); consumer.accept(freshBeanFactory); assertThat(freshBeanFactory.getBean(TestBean.class)).isNotNull(); @@ -105,7 +104,7 @@ class BeanRegistrationsAotContributionTests { void applyToWhenHasNameGeneratesPrefixedFeatureName() { this.generationContext = new DefaultGenerationContext( new ClassNameGenerator(TestTarget.class, "Management"), this.generatedFiles); - this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); + this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(this.generationContext); Map registrations = new LinkedHashMap<>(); RegisteredBean registeredBean = registerBean( new RootBeanDefinition(TestBean.class)); @@ -116,7 +115,7 @@ class BeanRegistrationsAotContributionTests { BeanRegistrationsAotContribution contribution = new BeanRegistrationsAotContribution( registrations); contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); - testCompiledResult((consumer, compiled) -> { + compile((consumer, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile.getClassName()).endsWith("__ManagementBeanDefinitions"); }); @@ -148,7 +147,7 @@ class BeanRegistrationsAotContributionTests { contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); assertThat(beanRegistrationsCodes).hasSize(1); BeanRegistrationsCode actual = beanRegistrationsCodes.get(0); - assertThat(actual.getMethodGenerator()).isNotNull(); + assertThat(actual.getMethods()).isNotNull(); } private RegisteredBean registerBean(RootBeanDefinition rootBeanDefinition) { @@ -158,27 +157,21 @@ class BeanRegistrationsAotContributionTests { } @SuppressWarnings({ "unchecked", "cast" }) - private void testCompiledResult( + private void compile( BiConsumer, Compiled> result) { + MethodReference methodReference = this.beanFactoryInitializationCode + .getInitializers().get(0); + this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + type.addModifiers(Modifier.PUBLIC); + type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); + type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) + .addParameter(DefaultListableBeanFactory.class, "beanFactory") + .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .build()); + }); this.generationContext.writeGeneratedContent(); - JavaFile javaFile = createJavaFile(); - TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo, - compiled -> result.accept(compiled.getInstance(Consumer.class), - compiled)); - } - - private JavaFile createJavaFile() { - MethodReference initializer = this.beanFactoryInitializationCode.getInitializers() - .get(0); - TypeSpec.Builder builder = TypeSpec.classBuilder("BeanFactoryConsumer"); - builder.addModifiers(Modifier.PUBLIC); - builder.addSuperinterface(ParameterizedTypeName.get(Consumer.class, - DefaultListableBeanFactory.class)); - builder.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) - .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(initializer.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) - .build()); - return JavaFile.builder("__", builder.build()).build(); + TestCompiler.forSystem().withFiles(this.generatedFiles).printFiles(System.out).compile(compiled -> + result.accept(compiled.getInstance(Consumer.class), compiled)); } } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java index beb867620c1..de04e7d9ab2 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java @@ -23,11 +23,10 @@ import java.util.function.Supplier; import javax.lang.model.element.Modifier; import org.assertj.core.api.ThrowingConsumer; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.DefaultGenerationContext; -import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.hint.ExecutableHint; import org.springframework.aot.hint.ExecutableMode; @@ -43,6 +42,7 @@ import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.TestBeanWithPrivateConstructor; +import org.springframework.beans.testfixture.beans.factory.aot.DeferredTypeBuilder; import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration; import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration.EnvironmentAwareComponent; import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration.NoDependencyComponent; @@ -53,12 +53,9 @@ import org.springframework.beans.testfixture.beans.factory.generator.factory.Sam import org.springframework.beans.testfixture.beans.factory.generator.injection.InjectionComponent; import org.springframework.core.env.StandardEnvironment; import org.springframework.core.testfixture.aot.generate.TestGenerationContext; -import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; -import org.springframework.javapoet.TypeSpec; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -71,17 +68,14 @@ import static org.assertj.core.api.Assertions.assertThat; */ class InstanceSupplierCodeGeneratorTests { - private InMemoryGeneratedFiles generatedFiles; + private final InMemoryGeneratedFiles generatedFiles; - private DefaultGenerationContext generationContext; + private final DefaultGenerationContext generationContext; private boolean allowDirectSupplierShortcut = false; - private ClassName className = ClassName.get("__", "InstanceSupplierSupplier"); - - @BeforeEach - void setup() { + InstanceSupplierCodeGeneratorTests() { this.generatedFiles = new InMemoryGeneratedFiles(); this.generationContext = new TestGenerationContext(this.generatedFiles); } @@ -91,7 +85,7 @@ class InstanceSupplierCodeGeneratorTests { void generateWhenHasDefaultConstructor() { BeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class); DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { TestBean bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(TestBean.class); assertThat(compiled.getSourceFile()) @@ -106,7 +100,7 @@ class InstanceSupplierCodeGeneratorTests { BeanDefinition beanDefinition = new RootBeanDefinition(InjectionComponent.class); DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); beanFactory.registerSingleton("injected", "injected"); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { InjectionComponent bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(InjectionComponent.class).extracting("bean") @@ -122,7 +116,7 @@ class InstanceSupplierCodeGeneratorTests { NoDependencyComponent.class); DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); beanFactory.registerSingleton("configuration", new InnerComponentConfiguration()); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { NoDependencyComponent bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(NoDependencyComponent.class); @@ -140,7 +134,7 @@ class InstanceSupplierCodeGeneratorTests { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); beanFactory.registerSingleton("configuration", new InnerComponentConfiguration()); beanFactory.registerSingleton("environment", new StandardEnvironment()); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { EnvironmentAwareComponent bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(EnvironmentAwareComponent.class); @@ -157,7 +151,7 @@ class InstanceSupplierCodeGeneratorTests { NumberHolderFactoryBean.class); DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); beanFactory.registerSingleton("number", 123); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { NumberHolder bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(NumberHolder.class); assertThat(bean).extracting("number").isNull(); // No property @@ -173,7 +167,7 @@ class InstanceSupplierCodeGeneratorTests { BeanDefinition beanDefinition = new RootBeanDefinition( TestBeanWithPrivateConstructor.class); DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { TestBeanWithPrivateConstructor bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(TestBeanWithPrivateConstructor.class); @@ -192,7 +186,7 @@ class InstanceSupplierCodeGeneratorTests { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder .genericBeanDefinition(SimpleConfiguration.class).getBeanDefinition()); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { String bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(String.class); assertThat(bean).isEqualTo("Hello"); @@ -212,7 +206,7 @@ class InstanceSupplierCodeGeneratorTests { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder .genericBeanDefinition(SimpleConfiguration.class).getBeanDefinition()); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { String bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(String.class); assertThat(bean).isEqualTo("Hello"); @@ -231,7 +225,7 @@ class InstanceSupplierCodeGeneratorTests { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder .genericBeanDefinition(SimpleConfiguration.class).getBeanDefinition()); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { Integer bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(Integer.class); assertThat(bean).isEqualTo(42); @@ -254,7 +248,7 @@ class InstanceSupplierCodeGeneratorTests { .genericBeanDefinition(SampleFactory.class).getBeanDefinition()); beanFactory.registerSingleton("number", 42); beanFactory.registerSingleton("string", "test"); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { String bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(String.class); assertThat(bean).isEqualTo("42test"); @@ -273,7 +267,7 @@ class InstanceSupplierCodeGeneratorTests { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder .genericBeanDefinition(SimpleConfiguration.class).getBeanDefinition()); - testCompiledResult(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { + compile(beanFactory, beanDefinition, (instanceSupplier, compiled) -> { Integer bean = getBean(beanFactory, beanDefinition, instanceSupplier); assertThat(bean).isInstanceOf(Integer.class); assertThat(bean).isEqualTo(42); @@ -308,41 +302,30 @@ class InstanceSupplierCodeGeneratorTests { } @SuppressWarnings("unchecked") - private void testCompiledResult(DefaultListableBeanFactory beanFactory, + private void compile(DefaultListableBeanFactory beanFactory, BeanDefinition beanDefinition, BiConsumer, Compiled> result) { - this.generationContext.writeGeneratedContent(); - DefaultListableBeanFactory registrationBeanFactory = new DefaultListableBeanFactory( - beanFactory); - registrationBeanFactory.registerBeanDefinition("testBean", beanDefinition); - RegisteredBean registeredBean = RegisteredBean.of(registrationBeanFactory, - "testBean"); - GeneratedMethods generatedMethods = new GeneratedMethods(); + DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(beanFactory); + freshBeanFactory.registerBeanDefinition("testBean", beanDefinition); + RegisteredBean registeredBean = RegisteredBean.of(freshBeanFactory, "testBean"); + DeferredTypeBuilder typeBuilder = new DeferredTypeBuilder(); + GeneratedClass generateClass = this.generationContext.getGeneratedClasses().addForFeature("TestCode", typeBuilder); InstanceSupplierCodeGenerator generator = new InstanceSupplierCodeGenerator( - this.generationContext, this.className, generatedMethods, - this.allowDirectSupplierShortcut); - Executable constructorOrFactoryMethod = ConstructorOrFactoryMethodResolver - .resolve(registeredBean); - CodeBlock generatedCode = generator.generateCode(registeredBean, - constructorOrFactoryMethod); - JavaFile javaFile = createJavaFile(generatedCode, generatedMethods); - TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo, - compiled -> result.accept( - (InstanceSupplier) compiled.getInstance(Supplier.class).get(), - compiled)); - } - - private JavaFile createJavaFile(CodeBlock generatedCode, - GeneratedMethods generatedMethods) { - TypeSpec.Builder builder = TypeSpec.classBuilder("InstanceSupplierSupplier"); - builder.addModifiers(Modifier.PUBLIC); - builder.addSuperinterface( - ParameterizedTypeName.get(Supplier.class, InstanceSupplier.class)); - builder.addMethod(MethodSpec.methodBuilder("get").addModifiers(Modifier.PUBLIC) - .returns(InstanceSupplier.class).addStatement("return $L", generatedCode) - .build()); - generatedMethods.doWithMethodSpecs(builder::addMethod); - return JavaFile.builder("__", builder.build()).build(); + this.generationContext, generateClass.getName(), + generateClass.getMethods(), this.allowDirectSupplierShortcut); + Executable constructorOrFactoryMethod = ConstructorOrFactoryMethodResolver.resolve(registeredBean); + CodeBlock generatedCode = generator.generateCode(registeredBean, constructorOrFactoryMethod); + typeBuilder.set(type -> { + type.addModifiers(Modifier.PUBLIC); + type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, InstanceSupplier.class)); + type.addMethod(MethodSpec.methodBuilder("get") + .addModifiers(Modifier.PUBLIC) + .returns(InstanceSupplier.class) + .addStatement("return $L", generatedCode).build()); + }); + this.generationContext.writeGeneratedContent(); + TestCompiler.forSystem().withFiles(this.generatedFiles).compile(compiled -> + result.accept((InstanceSupplier) compiled.getInstance(Supplier.class).get(), compiled)); } } diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/DeferredTypeBuilder.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/DeferredTypeBuilder.java new file mode 100644 index 00000000000..18459ad256e --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/DeferredTypeBuilder.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans.factory.aot; + +import java.util.function.Consumer; + +import org.springframework.javapoet.TypeSpec; +import org.springframework.util.Assert; + +/** + * {@link TypeSpec.Builder} {@link Consumer} that can be used to defer the to + * another consumer that is set at a later point. + * + * @author Phillip Webb + * @since 6.0 + */ +public class DeferredTypeBuilder implements Consumer { + + private Consumer type; + + @Override + public void accept(TypeSpec.Builder type) { + Assert.notNull(this.type, "No type builder set"); + this.type.accept(type); + } + + public void set(Consumer type) { + this.type = type; + } + +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java index b931cccdeb9..7d1eb4a237b 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java @@ -20,7 +20,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; @@ -28,16 +30,33 @@ import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; * Mock {@link BeanFactoryInitializationCode} implementation. * * @author Stephane Nicoll + * @author Phillip Webb */ public class MockBeanFactoryInitializationCode implements BeanFactoryInitializationCode { - private final GeneratedMethods generatedMethods = new GeneratedMethods(); + private final GeneratedClass generatedClass; private final List initializers = new ArrayList<>(); + private final DeferredTypeBuilder typeBuilder = new DeferredTypeBuilder(); + + + public MockBeanFactoryInitializationCode(GenerationContext generationContext) { + this.generatedClass = generationContext.getGeneratedClasses().addForFeature("TestCode", typeBuilder); + } + + + public DeferredTypeBuilder getTypeBuilder() { + return typeBuilder; + } + + public GeneratedClass getGeneratedClass() { + return generatedClass; + } + @Override - public GeneratedMethods getMethodGenerator() { - return this.generatedMethods; + public GeneratedMethods getMethods() { + return this.generatedClass.getMethods(); } @Override diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanRegistrationCode.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanRegistrationCode.java index 253547b3d17..9e7197afc38 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanRegistrationCode.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanRegistrationCode.java @@ -20,7 +20,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.aot.BeanRegistrationCode; import org.springframework.javapoet.ClassName; @@ -33,28 +35,34 @@ import org.springframework.javapoet.ClassName; */ public class MockBeanRegistrationCode implements BeanRegistrationCode { - private final ClassName className; - - private final GeneratedMethods generatedMethods = new GeneratedMethods(); + private final GeneratedClass generatedClass; private final List instancePostProcessors = new ArrayList<>(); - public MockBeanRegistrationCode(ClassName className) { - this.className = className; + private final DeferredTypeBuilder typeBuilder = new DeferredTypeBuilder(); + + + public MockBeanRegistrationCode(GenerationContext generationContext) { + this.generatedClass = generationContext.getGeneratedClasses().addForFeature("TestCode", this.typeBuilder); } - public MockBeanRegistrationCode() { - this(ClassName.get("com.example", "Test")); + + public DeferredTypeBuilder getTypeBuilder() { + return typeBuilder; + } + + public GeneratedClass getGeneratedClass() { + return this.generatedClass; } @Override public ClassName getClassName() { - return this.className; + return this.generatedClass.getName(); } @Override - public GeneratedMethods getMethodGenerator() { - return this.generatedMethods; + public GeneratedMethods getMethods() { + return this.generatedClass.getMethods(); } @Override diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanRegistrationsCode.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanRegistrationsCode.java index 1c123176166..8849d243619 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanRegistrationsCode.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanRegistrationsCode.java @@ -16,7 +16,9 @@ package org.springframework.beans.testfixture.beans.factory.aot; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.GenerationContext; import org.springframework.beans.factory.aot.BeanRegistrationsCode; import org.springframework.javapoet.ClassName; @@ -28,26 +30,32 @@ import org.springframework.javapoet.ClassName; */ public class MockBeanRegistrationsCode implements BeanRegistrationsCode { - private final ClassName className; + private final GeneratedClass generatedClass; - private final GeneratedMethods generatedMethods = new GeneratedMethods(); + private final DeferredTypeBuilder typeBuilder = new DeferredTypeBuilder(); - public MockBeanRegistrationsCode(ClassName className) { - this.className = className; + + public MockBeanRegistrationsCode(GenerationContext generationContext) { + this.generatedClass = generationContext.getGeneratedClasses().addForFeature("TestCode", this.typeBuilder); } - public MockBeanRegistrationsCode() { - this(ClassName.get("com.example", "Test")); + + public DeferredTypeBuilder getTypeBuilder() { + return typeBuilder; + } + + public GeneratedClass getGeneratedClass() { + return this.generatedClass; } @Override public ClassName getClassName() { - return this.className; + return this.generatedClass.getName(); } @Override - public GeneratedMethods getMethodGenerator() { - return this.generatedMethods; + public GeneratedMethods getMethods() { + return this.generatedClass.getMethods(); } } diff --git a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java index aa8b9e51c80..c106cb81623 100644 --- a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java +++ b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java @@ -534,10 +534,9 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo Map mappings = buildImportAwareMappings(); if (!mappings.isEmpty()) { GeneratedMethod generatedMethod = beanFactoryInitializationCode - .getMethodGenerator() - .generateMethod("addImportAwareBeanPostProcessors") - .using(builder -> generateAddPostProcessorMethod(builder, - mappings)); + .getMethods() + .add("addImportAwareBeanPostProcessors", method -> + generateAddPostProcessorMethod(method, mappings)); beanFactoryInitializationCode .addInitializer(MethodReference.of(generatedMethod.getName())); ResourceHints hints = generationContext.getRuntimeHints().resources(); @@ -546,14 +545,14 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo } } - private void generateAddPostProcessorMethod(MethodSpec.Builder builder, + private void generateAddPostProcessorMethod(MethodSpec.Builder method, Map mappings) { - builder.addJavadoc( + method.addJavadoc( "Add ImportAwareBeanPostProcessor to support ImportAware beans"); - builder.addModifiers(Modifier.PRIVATE); - builder.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_VARIABLE); - builder.addCode(generateAddPostProcessorCode(mappings)); + method.addModifiers(Modifier.PRIVATE); + method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_VARIABLE); + method.addCode(generateAddPostProcessorCode(mappings)); } private CodeBlock generateAddPostProcessorCode(Map mappings) { diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java index f5a450d1e6d..c958d1bc53b 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java @@ -16,7 +16,6 @@ package org.springframework.context.aot; -import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GenerationContext; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory; @@ -47,15 +46,11 @@ public class ApplicationContextAotGenerator { public ClassName generateApplicationContext(GenericApplicationContext applicationContext, GenerationContext generationContext) { applicationContext.refreshForAotProcessing(); - DefaultListableBeanFactory beanFactory = applicationContext - .getDefaultListableBeanFactory(); - ApplicationContextInitializationCodeGenerator codeGenerator = new ApplicationContextInitializationCodeGenerator(); - new BeanFactoryInitializationAotContributions(beanFactory).applyTo(generationContext, - codeGenerator); - GeneratedClass applicationContextInitializer = generationContext.getGeneratedClasses() - .forFeature("ApplicationContextInitializer") - .generate(codeGenerator.generateJavaFile()); - return applicationContextInitializer.getName(); + DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory(); + ApplicationContextInitializationCodeGenerator codeGenerator = + new ApplicationContextInitializationCodeGenerator(generationContext); + new BeanFactoryInitializationAotContributions(beanFactory).applyTo(generationContext, codeGenerator); + return codeGenerator.getGeneratedClass().getName(); } } diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java index fe532edda20..e22160f7174 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java @@ -18,12 +18,12 @@ package org.springframework.context.aot; import java.util.ArrayList; import java.util.List; -import java.util.function.Consumer; import javax.lang.model.element.Modifier; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethods; -import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; import org.springframework.beans.factory.support.DefaultListableBeanFactory; @@ -44,45 +44,41 @@ import org.springframework.javapoet.TypeSpec; class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitializationCode { + private static final String INITIALIZE_METHOD = "initialize"; + private static final String APPLICATION_CONTEXT_VARIABLE = "applicationContext"; - - private final GeneratedMethods generatedMethods = new GeneratedMethods(); - private final List initializers = new ArrayList<>(); + private final GeneratedClass generatedClass; - @Override - public MethodGenerator getMethodGenerator() { - return this.generatedMethods; + + ApplicationContextInitializationCodeGenerator(GenerationContext generationContext) { + this.generatedClass = generationContext.getGeneratedClasses() + .addForFeature("ApplicationContextInitializer", this::generateType); + this.generatedClass.reserveMethodNames(INITIALIZE_METHOD); } - @Override - public void addInitializer(MethodReference methodReference) { - this.initializers.add(methodReference); + + private void generateType(TypeSpec.Builder type) { + type.addJavadoc( + "{@link $T} to restore an application context based on previous AOT processing.", + ApplicationContextInitializer.class); + type.addModifiers(Modifier.PUBLIC); + type.addSuperinterface(ParameterizedTypeName.get( + ApplicationContextInitializer.class, GenericApplicationContext.class)); + type.addMethod(generateInitializeMethod()); } - Consumer generateJavaFile() { - return builder -> { - builder.addJavadoc( - "{@link $T} to restore an application context based on previous AOT processing.", - ApplicationContextInitializer.class); - builder.addModifiers(Modifier.PUBLIC); - builder.addSuperinterface(ParameterizedTypeName.get( - ApplicationContextInitializer.class, GenericApplicationContext.class)); - builder.addMethod(generateInitializeMethod()); - this.generatedMethods.doWithMethodSpecs(builder::addMethod); - }; - } private MethodSpec generateInitializeMethod() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("initialize"); - builder.addAnnotation(Override.class); - builder.addModifiers(Modifier.PUBLIC); - builder.addParameter(GenericApplicationContext.class, + MethodSpec.Builder method = MethodSpec.methodBuilder(INITIALIZE_METHOD); + method.addAnnotation(Override.class); + method.addModifiers(Modifier.PUBLIC); + method.addParameter(GenericApplicationContext.class, APPLICATION_CONTEXT_VARIABLE); - builder.addCode(generateInitializeCode()); - return builder.build(); + method.addCode(generateInitializeCode()); + return method.build(); } private CodeBlock generateInitializeCode() { @@ -93,10 +89,23 @@ class ApplicationContextInitializationCodeGenerator builder.addStatement("$L.setAutowireCandidateResolver(new $T())", BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class); for (MethodReference initializer : this.initializers) { - builder.addStatement( - initializer.toInvokeCodeBlock(CodeBlock.of(BEAN_FACTORY_VARIABLE))); + builder.addStatement(initializer.toInvokeCodeBlock(CodeBlock.of(BEAN_FACTORY_VARIABLE))); } return builder.build(); } + GeneratedClass getGeneratedClass() { + return this.generatedClass; + } + + @Override + public GeneratedMethods getMethods() { + return this.generatedClass.getMethods(); + } + + @Override + public void addInitializer(MethodReference methodReference) { + this.initializers.add(methodReference); + } + } diff --git a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java index 393315fd482..d22f648209b 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java @@ -39,10 +39,8 @@ import org.springframework.context.testfixture.context.generator.annotation.Impo import org.springframework.context.testfixture.context.generator.annotation.ImportConfiguration; import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; -import org.springframework.javapoet.TypeSpec; import org.springframework.lang.Nullable; import static org.assertj.core.api.Assertions.assertThat; @@ -56,21 +54,26 @@ import static org.assertj.core.api.Assertions.entry; */ class ConfigurationClassPostProcessorAotContributionTests { - private DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + private final InMemoryGeneratedFiles generatedFiles; - private InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); + private final DefaultGenerationContext generationContext; - private DefaultGenerationContext generationContext = new TestGenerationContext( - this.generatedFiles); + private final MockBeanFactoryInitializationCode beanFactoryInitializationCode; + + + ConfigurationClassPostProcessorAotContributionTests() { + this.generatedFiles = new InMemoryGeneratedFiles(); + this.generationContext = new TestGenerationContext(this.generatedFiles); + this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(this.generationContext); + } - private MockBeanFactoryInitializationCode beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); @Test void applyToWhenHasImportAwareConfigurationRegistersBeanPostProcessorWithMapEntry() { BeanFactoryInitializationAotContribution contribution = getContribution( ImportConfiguration.class); contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); - testCompiledResult((initializer, compiled) -> { + compile((initializer, compiled) -> { DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); initializer.accept(freshBeanFactory); ImportAwareAotBeanPostProcessor postProcessor = (ImportAwareAotBeanPostProcessor) freshBeanFactory @@ -89,8 +92,8 @@ class ConfigurationClassPostProcessorAotContributionTests { .singleElement() .satisfies(resourceHint -> assertThat(resourceHint.getIncludes()) .map(ResourcePatternHint::getPattern) - .containsOnly( - "org/springframework/context/testfixture/context/generator/annotation/ImportConfiguration.class")); + .containsOnly("org/springframework/context/testfixture/context/generator/annotation/" + + "ImportConfiguration.class")); } @Test @@ -100,38 +103,28 @@ class ConfigurationClassPostProcessorAotContributionTests { @Nullable private BeanFactoryInitializationAotContribution getContribution(Class type) { - this.beanFactory.registerBeanDefinition("configuration", - new RootBeanDefinition(type)); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("configuration", new RootBeanDefinition(type)); ConfigurationClassPostProcessor postProcessor = new ConfigurationClassPostProcessor(); - postProcessor.postProcessBeanFactory(this.beanFactory); - return postProcessor.processAheadOfTime(this.beanFactory); + postProcessor.postProcessBeanFactory(beanFactory); + return postProcessor.processAheadOfTime(beanFactory); } @SuppressWarnings("unchecked") - private void testCompiledResult( - BiConsumer, Compiled> result) { - JavaFile javaFile = createJavaFile(); + private void compile(BiConsumer, Compiled> result) { + MethodReference methodReference = this.beanFactoryInitializationCode + .getInitializers().get(0); + this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + type.addModifiers(Modifier.PUBLIC); + type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); + type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) + .addParameter(DefaultListableBeanFactory.class, "beanFactory") + .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .build()); + }); this.generationContext.writeGeneratedContent(); - TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo, - compiled -> result.accept(compiled.getInstance(Consumer.class), - compiled)); - } - - private JavaFile createJavaFile() { - MethodReference methodReference = this.beanFactoryInitializationCode.getInitializers() - .get(0); - TypeSpec.Builder builder = TypeSpec.classBuilder("TestConsumer"); - builder.addModifiers(Modifier.PUBLIC); - builder.addSuperinterface(ParameterizedTypeName.get(Consumer.class, - DefaultListableBeanFactory.class)); - builder.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) - .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement( - methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) - .build()); - this.beanFactoryInitializationCode.getMethodGenerator() - .doWithMethodSpecs(builder::addMethod); - return JavaFile.builder("__", builder.build()).build(); + TestCompiler.forSystem().withFiles(this.generatedFiles).compile(compiled -> + result.accept(compiled.getInstance(Consumer.class), compiled)); } private void assertPostProcessorEntry(ImportAwareAotBeanPostProcessor postProcessor, diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java index 30e30f34164..a40c5b68fbf 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java @@ -188,14 +188,11 @@ class ApplicationContextAotGeneratorTests { BiConsumer, Compiled> result) { ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); - DefaultGenerationContext generationContext = new TestGenerationContext( - generatedFiles); + DefaultGenerationContext generationContext = new TestGenerationContext(generatedFiles); generator.generateApplicationContext(applicationContext, generationContext); generationContext.writeGeneratedContent(); - TestCompiler.forSystem().withFiles(generatedFiles) - .compile(compiled -> result.accept( - compiled.getInstance(ApplicationContextInitializer.class), - compiled)); + TestCompiler.forSystem().withFiles(generatedFiles).compile(compiled -> + result.accept(compiled.getInstance(ApplicationContextInitializer.class), compiled)); } private GenericApplicationContext toFreshApplicationContext( diff --git a/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java index 5f2dd666c91..4c63acbcc4e 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java @@ -25,6 +25,7 @@ import java.lang.annotation.Target; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; @@ -53,7 +54,7 @@ class ReflectiveProcessorBeanRegistrationAotProcessorTests { private final ReflectiveProcessorBeanRegistrationAotProcessor processor = new ReflectiveProcessorBeanRegistrationAotProcessor(); - private final GenerationContext generationContext = new TestGenerationContext(); + private final GenerationContext generationContext = new TestGenerationContext(new InMemoryGeneratedFiles()); @Test void shouldIgnoreNonAnnotatedType() { diff --git a/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java index 16dd7185fb0..884a2e9943d 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java @@ -26,6 +26,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.hint.ResourceBundleHint; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; @@ -55,7 +56,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { @BeforeEach void setup() { - this.generationContext = new TestGenerationContext(); + this.generationContext = new TestGenerationContext(new InMemoryGeneratedFiles()); this.generator = new ApplicationContextAotGenerator(); } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java b/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java index 02d4d6bccb5..1758abaf0f4 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java @@ -47,6 +47,7 @@ public final class ClassNameGenerator { private final Map sequenceGenerator; + /** * Create a new instance using the specified {@code defaultTarget} and no * feature name prefix. @@ -68,11 +69,15 @@ public final class ClassNameGenerator { private ClassNameGenerator(Class defaultTarget, String featureNamePrefix, Map sequenceGenerator) { + Assert.notNull(defaultTarget, "'defaultTarget' must not be null"); this.defaultTarget = defaultTarget; this.featureNamePrefix = (!StringUtils.hasText(featureNamePrefix) ? "" : featureNamePrefix); this.sequenceGenerator = sequenceGenerator; } + String getFeatureNamePrefix() { + return this.featureNamePrefix; + } /** * Generate a unique {@link ClassName} based on the specified @@ -85,46 +90,22 @@ public final class ClassNameGenerator { * if any. *

Generated class names are unique. If such a feature was already * requested for this target, a counter is used to ensure uniqueness. - * @param target the class the newly generated class relates to, or - * {@code null} to use the main target * @param featureName the name of the feature that the generated class * supports + * @param target the class the newly generated class relates to, or + * {@code null} to use the main target * @return a unique generated class name */ - public ClassName generateClassName(@Nullable Class target, String featureName) { - return generateSequencedClassName(getClassName(target, featureName)); + public ClassName generateClassName(String featureName, @Nullable Class target) { + return generateSequencedClassName(getRootName(featureName, target)); } - /** - * Return a class name based on the specified {@code target} and - * {@code featureName}. This uses the same algorithm as - * {@link #generateClassName(Class, String)} but does not register - * the class name, nor add a unique suffix to it if necessary. - * @param target the class the newly generated class relates to, or - * {@code null} to use the main target - * @param featureName the name of the feature that the generated class - * supports - * @return the class name - */ - String getClassName(@Nullable Class target, String featureName) { + private String getRootName(String featureName, @Nullable Class target) { Assert.hasLength(featureName, "'featureName' must not be empty"); featureName = clean(featureName); Class targetToUse = (target != null ? target : this.defaultTarget); String featureNameToUse = this.featureNamePrefix + featureName; - return targetToUse.getName().replace("$", "_") - + SEPARATOR + StringUtils.capitalize(featureNameToUse); - } - - /** - * Return a new {@link ClassNameGenerator} instance for the specified - * feature name prefix, keeping track of all the class names generated - * by this instance. - * @param featureNamePrefix the feature name prefix to use - * @return a new instance for the specified feature name prefix - */ - ClassNameGenerator usingFeatureNamePrefix(String featureNamePrefix) { - return new ClassNameGenerator(this.defaultTarget, featureNamePrefix, - this.sequenceGenerator); + return targetToUse.getName().replace("$", "_") + SEPARATOR + StringUtils.capitalize(featureNameToUse); } private String clean(String name) { @@ -142,15 +123,26 @@ public final class ClassNameGenerator { } private ClassName generateSequencedClassName(String name) { - name = addSequence(name); + int sequence = this.sequenceGenerator.computeIfAbsent(name, key -> + new AtomicInteger()).getAndIncrement(); + if (sequence > 0) { + name = name + sequence; + } return ClassName.get(ClassUtils.getPackageName(name), ClassUtils.getShortName(name)); } - private String addSequence(String name) { - int sequence = this.sequenceGenerator - .computeIfAbsent(name, key -> new AtomicInteger()).getAndIncrement(); - return (sequence > 0) ? name + sequence : name; + + /** + * Return a new {@link ClassNameGenerator} instance for the specified + * feature name prefix, keeping track of all the class names generated + * by this instance. + * @param featureNamePrefix the feature name prefix to use + * @return a new instance for the specified feature name prefix + */ + ClassNameGenerator withFeatureNamePrefix(String featureNamePrefix) { + return new ClassNameGenerator(this.defaultTarget, featureNamePrefix, + this.sequenceGenerator); } } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java b/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java index dd500d0b84f..f9fba39af1b 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java @@ -74,9 +74,9 @@ public class DefaultGenerationContext implements GenerationContext { private DefaultGenerationContext(DefaultGenerationContext existing, String name) { int sequence = existing.sequenceGenerator .computeIfAbsent(name, key -> new AtomicInteger()).getAndIncrement(); - String nameToUse = (sequence > 0 ? name + sequence : name); + String featureName = (sequence > 0 ? name + sequence : name); this.sequenceGenerator = existing.sequenceGenerator; - this.generatedClasses = existing.generatedClasses.withName(nameToUse); + this.generatedClasses = existing.generatedClasses.withFeatureNamePrefix(featureName); this.generatedFiles = existing.generatedFiles; this.runtimeHints = existing.runtimeHints; } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java index 2776dfc9fe3..be591208fff 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java @@ -16,15 +16,18 @@ package org.springframework.aot.generate; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.TypeSpec; -import org.springframework.javapoet.TypeSpec.Builder; +import org.springframework.util.Assert; /** - * A generated class is a container for generated methods. + * A single generated class. * * @author Phillip Webb * @author Stephane Nicoll @@ -33,26 +36,49 @@ import org.springframework.javapoet.TypeSpec.Builder; */ public final class GeneratedClass { - private final Consumer typeSpecCustomizer; - private final ClassName name; private final GeneratedMethods methods; + private final Consumer type; + + private final Map methodNameSequenceGenerator = new ConcurrentHashMap<>(); + /** * Create a new {@link GeneratedClass} instance with the given name. This * constructor is package-private since names should only be generated via a * {@link GeneratedClasses}. * @param name the generated name + * @param type a {@link Consumer} used to build the type */ - GeneratedClass(Consumer typeSpecCustomizer, ClassName name) { - this.typeSpecCustomizer = typeSpecCustomizer; + GeneratedClass(ClassName name, Consumer type) { this.name = name; - this.methods = new GeneratedMethods(new MethodNameGenerator()); + this.type = type; + this.methods = new GeneratedMethods(this::generateSequencedMethodName); } + /** + * Update this instance with a set of reserved method names that should not + * be used for generated methods. Reserved names are often needed when a + * generated class implements a specific interface. + * @param reservedMethodNames the reserved method names + */ + public void reserveMethodNames(String... reservedMethodNames) { + for (String reservedMethodName : reservedMethodNames) { + String generatedName = generateSequencedMethodName(MethodName.of(reservedMethodNames)); + Assert.state(generatedName.equals(reservedMethodName), + () -> String.format("Unable to reserve method name '%s'", reservedMethodName)); + } + } + + private String generateSequencedMethodName(MethodName name) { + int sequence = this.methodNameSequenceGenerator + .computeIfAbsent(name, key -> new AtomicInteger()).getAndIncrement(); + return (sequence > 0) ? name.toString() + sequence : name.toString(); + } + /** * Return the name of the generated class. * @return the name of the generated class @@ -62,18 +88,28 @@ public final class GeneratedClass { } /** - * Return the method generator that can be used for this generated class. - * @return the method generator + * Return generated methods for this instance. + * @return the generated methods */ - public MethodGenerator getMethodGenerator() { + public GeneratedMethods getMethods() { return this.methods; } JavaFile generateJavaFile() { - TypeSpec.Builder typeSpecBuilder = TypeSpec.classBuilder(this.name); - this.typeSpecCustomizer.accept(typeSpecBuilder); - this.methods.doWithMethodSpecs(typeSpecBuilder::addMethod); - return JavaFile.builder(this.name.packageName(), typeSpecBuilder.build()).build(); + TypeSpec.Builder type = getBuilder(this.type); + this.methods.doWithMethodSpecs(type::addMethod); + return JavaFile.builder(this.name.packageName(), type.build()).build(); + } + + private TypeSpec.Builder getBuilder(Consumer type) { + TypeSpec.Builder builder = TypeSpec.classBuilder(this.name); + type.accept(builder); + return builder; + } + + void assertSameType(Consumer type) { + Assert.state(type == this.type || getBuilder(this.type).build().equals(getBuilder(type).build()), + "'type' consumer generated different result"); } } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java index 0b2b622b98e..c7ebfedc70f 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java @@ -46,11 +46,12 @@ public class GeneratedClasses { private final Map classesByOwner; + /** * Create a new instance using the specified naming conventions. * @param classNameGenerator the class name generator to use */ - public GeneratedClasses(ClassNameGenerator classNameGenerator) { + GeneratedClasses(ClassNameGenerator classNameGenerator) { this(classNameGenerator, new ArrayList<>(), new ConcurrentHashMap<>()); } @@ -62,29 +63,92 @@ public class GeneratedClasses { this.classesByOwner = classesByOwner; } + /** - * Prepare a {@link GeneratedClass} for the specified {@code featureName} - * targeting the specified {@code component}. - * @param featureName the name of the feature to associate with the generated class - * @param component the target component - * @return a {@link Builder} for further configuration + * Get or add a generated class for the specified {@code featureName} and no + * particular component. If this method has previously been called with the + * given {@code featureName} the existing class will be returned, otherwise + * a new class will be generated. + * @param featureName the name of the feature to associate with the + * generated class + * @param type a {@link Consumer} used to build the type + * @return an existing or newly generated class */ - public Builder forFeatureComponent(String featureName, Class component) { + public GeneratedClass getOrAddForFeature(String featureName, + Consumer type) { + Assert.hasLength(featureName, "'featureName' must not be empty"); - Assert.notNull(component, "'component' must not be null"); - return new Builder(featureName, component); + Assert.notNull(type, "'type' must not be null"); + Owner owner = new Owner(this.classNameGenerator.getFeatureNamePrefix(), featureName, null); + GeneratedClass generatedClass = this.classesByOwner.computeIfAbsent(owner, key -> createAndAddGeneratedClass(featureName, null, type)); + generatedClass.assertSameType(type); + return generatedClass; } /** - * Prepare a {@link GeneratedClass} for the specified {@code featureName} - * and no particular component. This should be used for high-level code - * generation that are widely applicable and for entry points. - * @param featureName the name of the feature to associate with the generated class - * @return a {@link Builder} for further configuration + * Get or add a generated class for the specified {@code featureName} + * targeting the specified {@code component}. If this method has previously + * been called with the given {@code featureName}/{@code target} the + * existing class will be returned, otherwise a new class will be generated, + * otherwise a new class will be generated. + * @param featureName the name of the feature to associate with the + * generated class + * @param targetComponent the target component + * @param type a {@link Consumer} used to build the type + * @return an existing or newly generated class */ - public Builder forFeature(String featureName) { + public GeneratedClass getOrAddForFeatureComponent(String featureName, + Class targetComponent, Consumer type) { + Assert.hasLength(featureName, "'featureName' must not be empty"); - return new Builder(featureName, null); + Assert.notNull(targetComponent, "'targetComponent' must not be null"); + Assert.notNull(type, "'type' must not be null"); + Owner owner = new Owner(this.classNameGenerator.getFeatureNamePrefix(), featureName, targetComponent); + GeneratedClass generatedClass = this.classesByOwner.computeIfAbsent(owner, key -> + createAndAddGeneratedClass(featureName, targetComponent, type)); + generatedClass.assertSameType(type); + return generatedClass; + } + + /** + * Add a new generated class for the specified {@code featureName} and no + * particular component. + * @param featureName the name of the feature to associate with the + * generated class + * @param type a {@link Consumer} used to build the type + * @return the newly generated class + */ + public GeneratedClass addForFeature(String featureName, Consumer type) { + Assert.hasLength(featureName, "'featureName' must not be empty"); + Assert.notNull(type, "'type' must not be null"); + return createAndAddGeneratedClass(featureName, null, type); + } + + /** + * Add a new generated class for the specified {@code featureName} targeting + * the specified {@code component}. + * @param featureName the name of the feature to associate with the + * generated class + * @param targetComponent the target component + * @param type a {@link Consumer} used to build the type + * @return the newly generated class + */ + public GeneratedClass addForFeatureComponent(String featureName, + Class targetComponent, Consumer type) { + + Assert.hasLength(featureName, "'featureName' must not be empty"); + Assert.notNull(targetComponent, "'targetComponent' must not be null"); + Assert.notNull(type, "'type' must not be null"); + return createAndAddGeneratedClass(featureName, targetComponent, type); + } + + private GeneratedClass createAndAddGeneratedClass(String featureName, + @Nullable Class targetComponent, Consumer type) { + + ClassName className = this.classNameGenerator.generateClassName(featureName, targetComponent); + GeneratedClass generatedClass = new GeneratedClass(className, type); + this.classes.add(generatedClass); + return generatedClass; } /** @@ -93,7 +157,7 @@ public class GeneratedClasses { * @param generatedFiles where to write the generated classes * @throws IOException on IO error */ - public void writeTo(GeneratedFiles generatedFiles) throws IOException { + void writeTo(GeneratedFiles generatedFiles) throws IOException { Assert.notNull(generatedFiles, "'generatedFiles' must not be null"); List generatedClasses = new ArrayList<>(this.classes); generatedClasses.sort(Comparator.comparing(GeneratedClass::getName)); @@ -102,62 +166,12 @@ public class GeneratedClasses { } } - GeneratedClasses withName(String name) { - return new GeneratedClasses(this.classNameGenerator.usingFeatureNamePrefix(name), + GeneratedClasses withFeatureNamePrefix(String name) { + return new GeneratedClasses(this.classNameGenerator.withFeatureNamePrefix(name), this.classes, this.classesByOwner); } - private record Owner(String id, String className) { - - } - - public class Builder { - - private final String featureName; - - @Nullable - private final Class target; - - - Builder(String featureName, @Nullable Class target) { - this.target = target; - this.featureName = featureName; - } - - /** - * Generate a new {@link GeneratedClass} using the specified type - * customizer. - * @param typeSpecCustomizer a customizer for the {@link TypeSpec.Builder} - * @return a new {@link GeneratedClass} - */ - public GeneratedClass generate(Consumer typeSpecCustomizer) { - Assert.notNull(typeSpecCustomizer, "'typeSpecCustomizer' must not be null"); - return createGeneratedClass(typeSpecCustomizer); - } - - - /** - * Get or generate a new {@link GeneratedClass} for the specified {@code id}. - * @param id a unique identifier - * @param typeSpecCustomizer a customizer for the {@link TypeSpec.Builder} - * @return a {@link GeneratedClass} instance - */ - public GeneratedClass getOrGenerate(String id, Consumer typeSpecCustomizer) { - Assert.hasLength(id, "'id' must not be empty"); - Assert.notNull(typeSpecCustomizer, "'typeSpecCustomizer' must not be null"); - Owner owner = new Owner(id, GeneratedClasses.this.classNameGenerator - .getClassName(this.target, this.featureName)); - return GeneratedClasses.this.classesByOwner.computeIfAbsent(owner, - key -> createGeneratedClass(typeSpecCustomizer)); - } - - private GeneratedClass createGeneratedClass(Consumer typeSpecCustomizer) { - ClassName className = GeneratedClasses.this.classNameGenerator - .generateClassName(this.target, this.featureName); - GeneratedClass generatedClass = new GeneratedClass(typeSpecCustomizer, className); - GeneratedClasses.this.classes.add(generatedClass); - return generatedClass; - } + private record Owner(String featureNamePrefix, String featureName, @Nullable Class target) { } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java index 51e1e035a14..6848bde9040 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java @@ -19,8 +19,6 @@ package org.springframework.aot.generate; import java.util.function.Consumer; import org.springframework.javapoet.MethodSpec; -import org.springframework.javapoet.MethodSpec.Builder; -import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -29,24 +27,28 @@ import org.springframework.util.Assert; * @author Phillip Webb * @since 6.0 * @see GeneratedMethods - * @see MethodGenerator */ public final class GeneratedMethod { private final String name; - @Nullable - private MethodSpec spec; + private final MethodSpec methodSpec; /** * Create a new {@link GeneratedMethod} instance with the given name. This * constructor is package-private since names should only be generated via * {@link GeneratedMethods}. - * @param name the generated name + * @param name the generated method name + * @param method consumer to generate the method */ - GeneratedMethod(String name) { + GeneratedMethod(String name, Consumer method) { this.name = name; + MethodSpec.Builder builder = MethodSpec.methodBuilder(getName()); + method.accept(builder); + this.methodSpec = builder.build(); + Assert.state(this.name.equals(this.methodSpec.name), + "'method' consumer must not change the generated method name"); } @@ -64,35 +66,13 @@ public final class GeneratedMethod { * @throws IllegalStateException if one of the {@code generateBy(...)} * methods has not been called */ - public MethodSpec getSpec() { - Assert.state(this.spec != null, - () -> "Method '%s' has no method spec defined".formatted(this.name)); - return this.spec; - } - - /** - * Generate the method using the given consumer. - * @param builder a consumer that will accept a method spec builder and - * configure it as necessary - * @return this instance - */ - public GeneratedMethod using(Consumer builder) { - Builder builderToUse = MethodSpec.methodBuilder(this.name); - builder.accept(builderToUse); - MethodSpec spec = builderToUse.build(); - assertNameHasNotBeenChanged(spec); - this.spec = spec; - return this; - } - - private void assertNameHasNotBeenChanged(MethodSpec spec) { - Assert.isTrue(this.name.equals(spec.name), - () -> "'spec' must use the generated name '%s'".formatted(this.name)); + MethodSpec getMethodSpec() { + return this.methodSpec; } @Override public String toString() { - return (this.spec != null) ? this.spec.toString() : this.name.toString(); + return this.name.toString(); } } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java index 27d391e8dc9..02069976d33 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java @@ -17,12 +17,13 @@ package org.springframework.aot.generate; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; import java.util.function.Consumer; +import java.util.function.Function; import java.util.stream.Stream; import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.MethodSpec.Builder; import org.springframework.util.Assert; /** @@ -32,49 +33,63 @@ import org.springframework.util.Assert; * @since 6.0 * @see GeneratedMethod */ -public class GeneratedMethods implements Iterable, MethodGenerator { +public class GeneratedMethods { - private final MethodNameGenerator methodNameGenerator; + private final Function methodNameGenerator; - private final List generatedMethods = new ArrayList<>(); + private final MethodName prefix; - - /** - * Create a new {@link GeneratedMethods} instance backed by a new - * {@link MethodNameGenerator}. - */ - public GeneratedMethods() { - this(new MethodNameGenerator()); - } + private final List generatedMethods; /** * Create a new {@link GeneratedMethods} instance backed by the given * {@link MethodNameGenerator}. * @param methodNameGenerator the method name generator */ - public GeneratedMethods(MethodNameGenerator methodNameGenerator) { + GeneratedMethods(Function methodNameGenerator) { Assert.notNull(methodNameGenerator, "'methodNameGenerator' must not be null"); this.methodNameGenerator = methodNameGenerator; + this.prefix = MethodName.NONE; + this.generatedMethods = new ArrayList<>(); } + private GeneratedMethods(Function methodNameGenerator, + MethodName prefix, List generatedMethods) { - @Override - public GeneratedMethod generateMethod(Object... methodNameParts) { - return add(methodNameParts); + this.methodNameGenerator = methodNameGenerator; + this.prefix = prefix; + this.generatedMethods = generatedMethods; } /** - * Add a new {@link GeneratedMethod}. The returned instance must define the - * method spec by calling {@code using(builder -> ...)}. - * @param methodNameParts the method name parts that should be used to - * generate a unique method name + * Add a new {@link GeneratedMethod}. + * @param suggestedName the suggested name for the method + * @param method a {@link Consumer} used to build method * @return the newly added {@link GeneratedMethod} */ - public GeneratedMethod add(Object... methodNameParts) { - GeneratedMethod method = new GeneratedMethod( - this.methodNameGenerator.generateMethodName(methodNameParts)); - this.generatedMethods.add(method); - return method; + public GeneratedMethod add(String suggestedName, Consumer method) { + Assert.notNull(suggestedName, "'suggestedName' must not be null"); + return add(MethodName.of(suggestedName), method); + } + + /** + * Add a new {@link GeneratedMethod}. + * @param suggestedName the suggested name for the method + * @param method a {@link Consumer} used to build the method + * @return the newly added {@link GeneratedMethod} + */ + public GeneratedMethod add(MethodName suggestedName, Consumer method) { + Assert.notNull(suggestedName, "'suggestedName' must not be null"); + Assert.notNull(method, "'method' must not be null"); + String generatedName = this.methodNameGenerator.apply(this.prefix.and(suggestedName)); + GeneratedMethod generatedMethod = new GeneratedMethod(generatedName, method); + this.generatedMethods.add(generatedMethod); + return generatedMethod; + } + + public GeneratedMethods withPrefix(String prefix) { + Assert.notNull(prefix, "'prefix' must not be null"); + return new GeneratedMethods(this.methodNameGenerator, this.prefix.and(prefix), this.generatedMethods); } /** @@ -82,20 +97,11 @@ public class GeneratedMethods implements Iterable, MethodGenera * that have been added to this collection. * @param action the action to perform */ - public void doWithMethodSpecs(Consumer action) { - stream().map(GeneratedMethod::getSpec).forEach(action); + void doWithMethodSpecs(Consumer action) { + stream().map(GeneratedMethod::getMethodSpec).forEach(action); } - @Override - public Iterator iterator() { - return this.generatedMethods.iterator(); - } - - /** - * Return a {@link Stream} of all the methods in this collection. - * @return a stream of {@link GeneratedMethod} instances - */ - public Stream stream() { + Stream stream() { return this.generatedMethods.stream(); } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java b/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java index d74d97e98fa..e203e6ab078 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java @@ -43,8 +43,7 @@ import org.springframework.aot.hint.SerializationHints; public interface GenerationContext { /** - * Return the {@link GeneratedClasses} being used by the context. Allows a - * single generated class to be shared across multiple AOT processors. All + * Return the {@link GeneratedClasses} being used by the context. All * generated classes are written at the end of AOT processing. * @return the generated classes */ diff --git a/spring-core/src/main/java/org/springframework/aot/generate/MethodGenerator.java b/spring-core/src/main/java/org/springframework/aot/generate/MethodGenerator.java deleted file mode 100644 index bb782ecef21..00000000000 --- a/spring-core/src/main/java/org/springframework/aot/generate/MethodGenerator.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -/** - * Generates new {@link GeneratedMethod} instances. - * - * @author Phillip Webb - * @since 6.0 - * @see GeneratedMethods - */ -@FunctionalInterface -public interface MethodGenerator { - - /** - * Generate a new {@link GeneratedMethod}. The returned instance must define - * the method spec by calling {@code using(builder -> ...)}. - * @param methodNameParts the method name parts that should be used to - * generate a unique method name - * @return the newly added {@link GeneratedMethod} - */ - GeneratedMethod generateMethod(Object... methodNameParts); - - /** - * Return a new {@link MethodGenerator} instance that generates method with - * additional implicit method name parts. The final generated name will be - * of the following form: - *

- * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - *
OriginalUpdated
run<name>Run
getValueget<Name>Value
setValueset<Name>Value
isEnabledis<Name>Enabled
- * @param nameParts the implicit name parts - * @return a new {@link MethodGenerator} instance - */ - default MethodGenerator withName(Object... nameParts) { - return new MethodGeneratorWithName(this, nameParts); - } - -} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/MethodGeneratorWithName.java b/spring-core/src/main/java/org/springframework/aot/generate/MethodGeneratorWithName.java deleted file mode 100644 index dee36e1894e..00000000000 --- a/spring-core/src/main/java/org/springframework/aot/generate/MethodGeneratorWithName.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import org.springframework.util.ObjectUtils; -import org.springframework.util.StringUtils; - -/** - * Internal class used to support {@link MethodGenerator#withName(Object...)}. - * - * @author Phillip Webb - * @since 6.0 - */ -class MethodGeneratorWithName implements MethodGenerator { - - private static final String[] PREFIXES = { "get", "set", "is" }; - - private final MethodGenerator methodGenerator; - - private final Object[] nameParts; - - - MethodGeneratorWithName(MethodGenerator methodGenerator, Object[] nameParts) { - this.methodGenerator = methodGenerator; - this.nameParts = nameParts; - } - - - @Override - public GeneratedMethod generateMethod(Object... methodNameParts) { - return this.methodGenerator.generateMethod(generateName(methodNameParts)); - } - - private Object[] generateName(Object... methodNameParts) { - String joined = MethodNameGenerator.join(methodNameParts); - String prefix = getPrefix(joined); - String suffix = joined.substring(prefix.length()); - Object[] result = this.nameParts; - if (StringUtils.hasLength(prefix)) { - result = ObjectUtils.addObjectToArray(result, prefix, 0); - } - if (StringUtils.hasLength(suffix)) { - result = ObjectUtils.addObjectToArray(result, suffix); - } - return result; - } - - private String getPrefix(String name) { - for (String candidate : PREFIXES) { - if (name.startsWith(candidate)) { - return candidate; - } - } - return ""; - } - -} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/MethodName.java b/spring-core/src/main/java/org/springframework/aot/generate/MethodName.java new file mode 100644 index 00000000000..af218d558d3 --- /dev/null +++ b/spring-core/src/main/java/org/springframework/aot/generate/MethodName.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import java.util.Arrays; +import java.util.stream.Collectors; + +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * A camel-case method name that can be built from distinct parts. + * + * @author Phillip Webb + * @since 6.0 + */ +public final class MethodName { + + private static final String[] PREFIXES = { "get", "set", "is" }; + + /** + * An empty method name. + */ + public static final MethodName NONE = of(); + + private final String value; + + + private MethodName(String value) { + this.value = value; + } + + + /** + * Create a new method name from the specific parts. The returned name will + * be in camel-case and will only contain valid characters from the parts. + * @param parts the parts the form the name + * @return a method name instance + */ + public static MethodName of(String... parts) { + Assert.notNull(parts, "'parts' must not be null"); + return new MethodName(join(parts)); + } + + /** + * Create a new method name by concatenating the specified name to this name. + * @param name the name to concatenate + * @return a new method name instance + */ + public MethodName and(MethodName name) { + Assert.notNull(name, "'name' must not be null"); + return and(name.value); + } + + /** + * Create a new method name by concatenating the specified parts to this name. + * @param parts the parts to concatenate + * @return a new method name instance + */ + public MethodName and(String... parts) { + Assert.notNull(parts, "'parts' must not be null"); + String joined = join(parts); + String prefix = getPrefix(joined); + String suffix = joined.substring(prefix.length()); + return of(prefix, this.value, suffix); + } + + private String getPrefix(String name) { + for (String candidate : PREFIXES) { + if (name.startsWith(candidate)) { + return candidate; + } + } + return ""; + } + + + @Override + public int hashCode() { + return this.value.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + return this.value.equals(((MethodName) obj).value); + } + + @Override + public String toString() { + return (!StringUtils.hasLength(this.value)) ? "$$aot" : this.value ; + } + + private static String join(String[] parts) { + return StringUtils.uncapitalize(Arrays.stream(parts).map(MethodName::clean) + .map(StringUtils::capitalize).collect(Collectors.joining())); + } + + private static String clean(String part) { + char[] chars = (part != null) ? part.toCharArray() : new char[0]; + StringBuilder name = new StringBuilder(chars.length); + boolean uppercase = false; + for (char ch : chars) { + char outputChar = (!uppercase) ? ch : Character.toUpperCase(ch); + name.append((!Character.isLetter(ch)) ? "" : outputChar); + uppercase = (ch == '.'); + } + return name.toString(); + } + +} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/MethodNameGenerator.java b/spring-core/src/main/java/org/springframework/aot/generate/MethodNameGenerator.java deleted file mode 100644 index a8b30724a01..00000000000 --- a/spring-core/src/main/java/org/springframework/aot/generate/MethodNameGenerator.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.springframework.lang.Nullable; -import org.springframework.util.Assert; -import org.springframework.util.ClassUtils; -import org.springframework.util.StringUtils; - -/** - * Generates unique method names that can be used in ahead-of-time generated - * source code. This class is stateful so one instance should be used per - * generated type. - * - * @author Phillip Webb - * @since 6.0 - */ -public class MethodNameGenerator { - - private final Map sequenceGenerator = new ConcurrentHashMap<>(); - - - /** - * Create a new {@link MethodNameGenerator} instance without any reserved - * names. - */ - public MethodNameGenerator() { - } - - /** - * Create a new {@link MethodNameGenerator} instance with the specified - * reserved names. - * @param reservedNames the method names to reserve - */ - public MethodNameGenerator(String... reservedNames) { - this(List.of(reservedNames)); - } - - /** - * Create a new {@link MethodNameGenerator} instance with the specified - * reserved names. - * @param reservedNames the method names to reserve - */ - public MethodNameGenerator(Iterable reservedNames) { - Assert.notNull(reservedNames, "'reservedNames' must not be null"); - for (String reservedName : reservedNames) { - addSequence(StringUtils.uncapitalize(reservedName)); - } - } - - - /** - * Generate a new method name from the given parts. - * @param parts the parts used to build the name. - * @return the generated method name - */ - public String generateMethodName(Object... parts) { - String generatedName = join(parts); - return addSequence(generatedName.isEmpty() ? "$$aot" : generatedName); - } - - private String addSequence(String name) { - int sequence = this.sequenceGenerator - .computeIfAbsent(name, key -> new AtomicInteger()).getAndIncrement(); - return (sequence > 0) ? name + sequence : name; - } - - /** - * Join the specified parts to create a valid camel case method name. - * @param parts the parts to join - * @return a method name from the joined parts. - */ - public static String join(Object... parts) { - Stream capitalizedPartNames = Arrays.stream(parts) - .map(MethodNameGenerator::getPartName).map(StringUtils::capitalize); - return StringUtils.uncapitalize(capitalizedPartNames.collect(Collectors.joining())); - } - - private static String getPartName(@Nullable Object part) { - if (part == null) { - return ""; - } - if (part instanceof Class clazz) { - return clean(ClassUtils.getShortName(clazz)); - } - return clean(part.toString()); - } - - private static String clean(String string) { - char[] chars = string.toCharArray(); - StringBuilder name = new StringBuilder(chars.length); - boolean uppercase = false; - for (char ch : chars) { - char outputChar = (!uppercase) ? ch : Character.toUpperCase(ch); - name.append((!Character.isLetter(ch)) ? "" : outputChar); - uppercase = ch == '.'; - } - return name.toString(); - } - -} diff --git a/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java b/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java index ae8354341d2..de6f780178b 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java @@ -36,72 +36,64 @@ class ClassNameGeneratorTests { @Test void generateClassNameWhenTargetClassIsNullUsesMainTarget() { - ClassName generated = this.generator.generateClassName(null, "test"); + ClassName generated = this.generator.generateClassName("test", null); assertThat(generated).hasToString("java.lang.Object__Test"); } @Test void generateClassNameUseFeatureNamePrefix() { ClassName generated = new ClassNameGenerator(Object.class, "One") - .generateClassName(InputStream.class, "test"); + .generateClassName("test", InputStream.class); assertThat(generated).hasToString("java.io.InputStream__OneTest"); } @Test void generateClassNameWithNoTextFeatureNamePrefix() { ClassName generated = new ClassNameGenerator(Object.class, " ") - .generateClassName(InputStream.class, "test"); + .generateClassName("test", InputStream.class); assertThat(generated).hasToString("java.io.InputStream__Test"); } @Test void generatedClassNameWhenFeatureIsEmptyThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> this.generator.generateClassName(InputStream.class, "")) + .isThrownBy(() -> this.generator.generateClassName("", InputStream.class)) .withMessage("'featureName' must not be empty"); } @Test void generatedClassNameWhenFeatureIsNotAllLettersThrowsException() { - assertThat(this.generator.generateClassName(InputStream.class, "name!")) + assertThat(this.generator.generateClassName("name!", InputStream.class)) .hasToString("java.io.InputStream__Name"); - assertThat(this.generator.generateClassName(InputStream.class, "1NameHere")) + assertThat(this.generator.generateClassName("1NameHere", InputStream.class)) .hasToString("java.io.InputStream__NameHere"); - assertThat(this.generator.generateClassName(InputStream.class, "Y0pe")) + assertThat(this.generator.generateClassName("Y0pe", InputStream.class)) .hasToString("java.io.InputStream__YPe"); } @Test void generateClassNameWithClassWhenLowercaseFeatureNameGeneratesName() { - ClassName generated = this.generator.generateClassName(InputStream.class, "bytes"); + ClassName generated = this.generator.generateClassName("bytes", InputStream.class); assertThat(generated).hasToString("java.io.InputStream__Bytes"); } @Test void generateClassNameWithClassWhenInnerClassGeneratesName() { - ClassName generated = this.generator.generateClassName(TestBean.class, "EventListener"); + ClassName generated = this.generator.generateClassName("EventListener", TestBean.class); assertThat(generated) .hasToString("org.springframework.aot.generate.ClassNameGeneratorTests_TestBean__EventListener"); } @Test void generateClassWithClassWhenMultipleCallsGeneratesSequencedName() { - ClassName generated1 = this.generator.generateClassName(InputStream.class, "bytes"); - ClassName generated2 = this.generator.generateClassName(InputStream.class, "bytes"); - ClassName generated3 = this.generator.generateClassName(InputStream.class, "bytes"); + ClassName generated1 = this.generator.generateClassName("bytes", InputStream.class); + ClassName generated2 = this.generator.generateClassName("bytes", InputStream.class); + ClassName generated3 = this.generator.generateClassName("bytes", InputStream.class); assertThat(generated1).hasToString("java.io.InputStream__Bytes"); assertThat(generated2).hasToString("java.io.InputStream__Bytes1"); assertThat(generated3).hasToString("java.io.InputStream__Bytes2"); } - @Test - void getClassNameWhenMultipleCallsReturnsSameName() { - String name1 = this.generator.getClassName(InputStream.class, "bytes"); - String name2 = this.generator.getClassName(InputStream.class, "bytes"); - String name3 = this.generator.getClassName(InputStream.class, "bytes"); - assertThat(name1).hasToString("java.io.InputStream__Bytes") - .isEqualTo(name2).isEqualTo(name3); - } static class TestBean { diff --git a/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java b/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java index 306a5ae0475..dc8d9937fca 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java @@ -23,7 +23,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.generate.GeneratedFiles.Kind; import org.springframework.aot.hint.RuntimeHints; import org.springframework.core.testfixture.aot.generate.TestTarget; -import org.springframework.javapoet.TypeSpec.Builder; +import org.springframework.javapoet.TypeSpec; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -36,7 +36,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException */ class DefaultGenerationContextTests { - private static final Consumer typeSpecCustomizer = type -> {}; + private static final Consumer typeSpecCustomizer = type -> {}; private final GeneratedClasses generatedClasses = new GeneratedClasses( new ClassNameGenerator(TestTarget.class)); @@ -113,34 +113,34 @@ class DefaultGenerationContextTests { new ClassNameGenerator(TestTarget.class), this.generatedFiles); GenerationContext anotherContext = context.withName("Another"); GeneratedClass generatedClass = anotherContext.getGeneratedClasses() - .forFeature("Test").generate(typeSpecCustomizer); + .addForFeature("Test", typeSpecCustomizer); assertThat(generatedClass.getName().simpleName()).endsWith("__AnotherTest"); } @Test - void withNameKeepTrackOfAllGeneratedFiles() { + void withNameKeepsTrackOfAllGeneratedFiles() { DefaultGenerationContext context = new DefaultGenerationContext( new ClassNameGenerator(TestTarget.class), this.generatedFiles); - context.getGeneratedClasses().forFeature("Test").generate(typeSpecCustomizer); + context.getGeneratedClasses().addForFeature("Test", typeSpecCustomizer); GenerationContext anotherContext = context.withName("Another"); assertThat(anotherContext.getGeneratedClasses()).isNotSameAs(context.getGeneratedClasses()); assertThat(anotherContext.getGeneratedFiles()).isSameAs(context.getGeneratedFiles()); assertThat(anotherContext.getRuntimeHints()).isSameAs(context.getRuntimeHints()); - anotherContext.getGeneratedClasses().forFeature("Test").generate(typeSpecCustomizer); + anotherContext.getGeneratedClasses().addForFeature("Test", typeSpecCustomizer); context.writeGeneratedContent(); assertThat(this.generatedFiles.getGeneratedFiles(Kind.SOURCE)).hasSize(2); } @Test - void withNameGenerateUniqueName() { + void withNameGeneratesUniqueName() { DefaultGenerationContext context = new DefaultGenerationContext( new ClassNameGenerator(Object.class), this.generatedFiles); context.withName("Test").getGeneratedClasses() - .forFeature("Feature").generate(typeSpecCustomizer); + .addForFeature("Feature", typeSpecCustomizer); context.withName("Test").getGeneratedClasses() - .forFeature("Feature").generate(typeSpecCustomizer); + .addForFeature("Feature", typeSpecCustomizer); context.withName("Test").getGeneratedClasses() - .forFeature("Feature").generate(typeSpecCustomizer); + .addForFeature("Feature", typeSpecCustomizer); context.writeGeneratedContent(); assertThat(this.generatedFiles.getGeneratedFiles(Kind.SOURCE)).containsOnlyKeys( "java/lang/Object__TestFeature.java", diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java index df0245c6523..bab98ebe6b2 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java @@ -21,9 +21,11 @@ import java.util.function.Consumer; import org.junit.jupiter.api.Test; import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.TypeSpec.Builder; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.TypeSpec; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** * Tests for {@link GeneratedClass}. @@ -33,26 +35,49 @@ import static org.assertj.core.api.Assertions.assertThat; */ class GeneratedClassTests { + private static final Consumer emptyTypeCustomizer = type -> {}; + + private static final Consumer emptyMethodCustomizer = method -> {}; + @Test void getNameReturnsName() { ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(emptyTypeSpec(), name); + GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer); assertThat(generatedClass.getName()).isSameAs(name); } + @Test + void reserveMethodNamesWhenNameUsedThrowsException() { + ClassName name = ClassName.bestGuess("com.example.Test"); + GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer); + generatedClass.getMethods().add("apply", emptyMethodCustomizer); + assertThatIllegalStateException() + .isThrownBy(() -> generatedClass.reserveMethodNames("apply")); + } + + @Test + void reserveMethodNamesReservesNames() { + ClassName name = ClassName.bestGuess("com.example.Test"); + GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer); + generatedClass.reserveMethodNames("apply"); + GeneratedMethod generatedMethod = generatedClass.getMethods().add("apply", emptyMethodCustomizer); + assertThat(generatedMethod.getName()).isEqualTo("apply1"); + } + + @Test + void generateMethodNameWhenAllEmptyPartsGeneratesSetName() { + ClassName name = ClassName.bestGuess("com.example.Test"); + GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer); + GeneratedMethod generatedMethod = generatedClass.getMethods().add("123", emptyMethodCustomizer); + assertThat(generatedMethod.getName()).isEqualTo("$$aot"); + } + @Test void generateJavaFileIncludesGeneratedMethods() { ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(emptyTypeSpec(), name); - MethodGenerator methodGenerator = generatedClass.getMethodGenerator(); - methodGenerator.generateMethod("test") - .using(builder -> builder.addJavadoc("Test Method")); + GeneratedClass generatedClass = new GeneratedClass(name, emptyTypeCustomizer); + generatedClass.getMethods().add("test", method -> method.addJavadoc("Test Method")); assertThat(generatedClass.generateJavaFile().toString()).contains("Test Method"); } - - private Consumer emptyTypeSpec() { - return type -> {}; - } - } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java index 7aca293319b..2edca07b55a 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java @@ -23,7 +23,6 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.generate.GeneratedFiles.Kind; import org.springframework.javapoet.TypeSpec; -import org.springframework.javapoet.TypeSpec.Builder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -52,93 +51,112 @@ class GeneratedClassesTests { } @Test - void forFeatureComponentWhenTargetIsNullThrowsException() { + void addForFeatureComponentWhenFeatureNameIsEmptyThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> this.generatedClasses.forFeatureComponent("test", null)) - .withMessage("'component' must not be null"); - } - - @Test - void forFeatureComponentWhenFeatureNameIsEmptyThrowsException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.generatedClasses.forFeatureComponent("", TestComponent.class)) + .isThrownBy(() -> this.generatedClasses.addForFeatureComponent("", + TestComponent.class, emptyTypeCustomizer)) .withMessage("'featureName' must not be empty"); } @Test - void forFeatureWhenFeatureNameIsEmptyThrowsException() { + void addForFeatureWhenFeatureNameIsEmptyThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> this.generatedClasses.forFeature("")) + .isThrownBy(() -> this.generatedClasses.addForFeature("", emptyTypeCustomizer)) .withMessage("'featureName' must not be empty"); } @Test - void generateWhenTypeSpecCustomizerIsNullThrowsException() { + void addForFeatureComponentWhenTypeSpecCustomizerIsNullThrowsException() { assertThatIllegalArgumentException() .isThrownBy(() -> this.generatedClasses - .forFeatureComponent("test", TestComponent.class).generate(null)) - .withMessage("'typeSpecCustomizer' must not be null"); + .addForFeatureComponent("test", TestComponent.class, null)) + .withMessage("'type' must not be null"); } @Test - void forFeatureUsesDefaultTarget() { - GeneratedClass generatedClass = this.generatedClasses - .forFeature("Test").generate(emptyTypeCustomizer); + void addForFeatureUsesDefaultTarget() { + GeneratedClass generatedClass = this.generatedClasses.addForFeature("Test", emptyTypeCustomizer); assertThat(generatedClass.getName()).hasToString("java.lang.Object__Test"); } @Test - void forFeatureComponentUsesComponent() { + void addForFeatureComponentUsesTarget() { GeneratedClass generatedClass = this.generatedClasses - .forFeatureComponent("Test", TestComponent.class).generate(emptyTypeCustomizer); + .addForFeatureComponent("Test", TestComponent.class, emptyTypeCustomizer); assertThat(generatedClass.getName().toString()).endsWith("TestComponent__Test"); } @Test - void generateReturnsDifferentInstances() { - Consumer typeCustomizer = mockTypeCustomizer(); + void addForFeatureComponentWithSameNameReturnsDifferentInstances() { GeneratedClass generatedClass1 = this.generatedClasses - .forFeatureComponent("one", TestComponent.class).generate(typeCustomizer); + .addForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); GeneratedClass generatedClass2 = this.generatedClasses - .forFeatureComponent("one", TestComponent.class).generate(typeCustomizer); + .addForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); assertThat(generatedClass1).isNotSameAs(generatedClass2); assertThat(generatedClass1.getName().simpleName()).endsWith("__One"); assertThat(generatedClass2.getName().simpleName()).endsWith("__One1"); } @Test - void getOrGenerateWhenNewReturnsGeneratedMethod() { - Consumer typeCustomizer = mockTypeCustomizer(); + void getOrAddForFeatureComponentWhenNewReturnsGeneratedMethod() { GeneratedClass generatedClass1 = this.generatedClasses - .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); + .getOrAddForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); GeneratedClass generatedClass2 = this.generatedClasses - .forFeatureComponent("two", TestComponent.class).getOrGenerate("facet", typeCustomizer); + .getOrAddForFeatureComponent("two", TestComponent.class, emptyTypeCustomizer); assertThat(generatedClass1).isNotNull().isNotEqualTo(generatedClass2); assertThat(generatedClass2).isNotNull(); } @Test - void getOrGenerateWhenRepeatReturnsSameGeneratedMethod() { - Consumer typeCustomizer = mockTypeCustomizer(); + void getOrAddForFeatureWhenNewReturnsGeneratedMethod() { GeneratedClass generatedClass1 = this.generatedClasses - .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); + .getOrAddForFeature("one", emptyTypeCustomizer); GeneratedClass generatedClass2 = this.generatedClasses - .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); + .getOrAddForFeature("two", emptyTypeCustomizer); + assertThat(generatedClass1).isNotNull().isNotEqualTo(generatedClass2); + assertThat(generatedClass2).isNotNull(); + } + + @Test + void getOrAddForFeatureComponentWhenRepeatReturnsSameGeneratedMethod() { + GeneratedClass generatedClass1 = this.generatedClasses + .getOrAddForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); + GeneratedClass generatedClass2 = this.generatedClasses + .getOrAddForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); GeneratedClass generatedClass3 = this.generatedClasses - .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); + .getOrAddForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); assertThat(generatedClass1).isNotNull().isSameAs(generatedClass2) .isSameAs(generatedClass3); - verifyNoInteractions(typeCustomizer); - generatedClass1.generateJavaFile(); - verify(typeCustomizer).accept(any()); + } + + @Test + void getOrAddForFeatureWhenRepeatReturnsSameGeneratedMethod() { + GeneratedClass generatedClass1 = this.generatedClasses + .getOrAddForFeature("one", emptyTypeCustomizer); + GeneratedClass generatedClass2 = this.generatedClasses + .getOrAddForFeature("one", emptyTypeCustomizer); + GeneratedClass generatedClass3 = this.generatedClasses + .getOrAddForFeature("one", emptyTypeCustomizer); + assertThat(generatedClass1).isNotNull().isSameAs(generatedClass2) + .isSameAs(generatedClass3); + } + + @Test + void getOrAddForFeatureComponentWhenHasFeatureNamePrefix() { + GeneratedClasses prefixed = this.generatedClasses.withFeatureNamePrefix("prefix"); + GeneratedClass generatedClass1 = this.generatedClasses.getOrAddForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); + GeneratedClass generatedClass2 = this.generatedClasses.getOrAddForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); + GeneratedClass generatedClass3 = prefixed.getOrAddForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); + GeneratedClass generatedClass4 = prefixed.getOrAddForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); + assertThat(generatedClass1).isSameAs(generatedClass2).isNotSameAs(generatedClass3); + assertThat(generatedClass3).isSameAs(generatedClass4); } @Test @SuppressWarnings("unchecked") void writeToInvokeTypeSpecCustomizer() throws IOException { Consumer typeSpecCustomizer = mock(Consumer.class); - this.generatedClasses.forFeatureComponent("one", TestComponent.class) - .generate(typeSpecCustomizer); + this.generatedClasses.addForFeatureComponent("one", TestComponent.class, typeSpecCustomizer); verifyNoInteractions(typeSpecCustomizer); InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); this.generatedClasses.writeTo(generatedFiles); @@ -149,20 +167,14 @@ class GeneratedClassesTests { @Test void withNameUpdatesNamingConventions() { GeneratedClass generatedClass1 = this.generatedClasses - .forFeatureComponent("one", TestComponent.class).generate(emptyTypeCustomizer); - GeneratedClass generatedClass2 = this.generatedClasses.withName("Another") - .forFeatureComponent("one", TestComponent.class).generate(emptyTypeCustomizer); + .addForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); + GeneratedClass generatedClass2 = this.generatedClasses.withFeatureNamePrefix("Another") + .addForFeatureComponent("one", TestComponent.class, emptyTypeCustomizer); assertThat(generatedClass1.getName().toString()).endsWith("TestComponent__One"); assertThat(generatedClass2.getName().toString()).endsWith("TestComponent__AnotherOne"); } - @SuppressWarnings("unchecked") - private Consumer mockTypeCustomizer() { - return mock(Consumer.class); - } - - private static class TestComponent { } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java index 44a1578aedb..f080ce8ed20 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java @@ -16,12 +16,13 @@ package org.springframework.aot.generate; -import javax.lang.model.element.Modifier; +import java.util.function.Consumer; import org.junit.jupiter.api.Test; +import org.springframework.javapoet.MethodSpec; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** @@ -31,42 +32,27 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException; */ class GeneratedMethodTests { + private static final Consumer methodSpecCustomizer = method -> {}; + private static final String NAME = "spring"; @Test void getNameReturnsName() { - GeneratedMethod method = new GeneratedMethod(NAME); - assertThat(method.getName()).isSameAs(NAME); + GeneratedMethod generatedMethod = new GeneratedMethod(NAME, methodSpecCustomizer); + assertThat(generatedMethod.getName()).isSameAs(NAME); } @Test - void getSpecReturnsSpec() { - GeneratedMethod method = new GeneratedMethod(NAME); - method.using(builder -> builder.addJavadoc("Test")); - assertThat(method.getSpec().javadoc).asString().contains("Test"); + void generateMethodSpecReturnsMethodSpec() { + GeneratedMethod generatedMethod = new GeneratedMethod(NAME, method -> method.addJavadoc("Test")); + assertThat(generatedMethod.getMethodSpec().javadoc).asString().contains("Test"); } @Test - void getSpecReturnsSpecWhenNoSpecDefinedThrowsException() { - GeneratedMethod method = new GeneratedMethod(NAME); - assertThatIllegalStateException().isThrownBy(() -> method.getSpec()) - .withMessage("Method 'spring' has no method spec defined"); - } - - @Test - void usingAddsSpec() { - GeneratedMethod method = new GeneratedMethod(NAME); - method.using(builder -> builder.addModifiers(Modifier.PUBLIC)); - assertThat(method.getSpec()).asString() - .isEqualToIgnoringNewLines("public void spring() {}"); - } - - @Test - void usingWhenBuilderChanagesNameThrowsException() { - GeneratedMethod method = new GeneratedMethod(NAME); - assertThatIllegalArgumentException() - .isThrownBy(() -> method.using(builder -> builder.setName("badname"))) - .withMessage("'spec' must use the generated name 'spring'"); + void generateMethodSpecWhenMethodNameIsChangedThrowsException() { + assertThatIllegalStateException().isThrownBy(() -> + new GeneratedMethod(NAME, method -> method.setName("badname")).getMethodSpec()) + .withMessage("'method' consumer must not change the generated method name"); } } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java index 6b091eff0ce..a4755f52945 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java @@ -17,8 +17,9 @@ package org.springframework.aot.generate; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; import org.junit.jupiter.api.Test; @@ -26,7 +27,6 @@ import org.springframework.javapoet.MethodSpec; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** * Tests for {@link GeneratedMethods}. @@ -35,7 +35,9 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException; */ class GeneratedMethodsTests { - private final GeneratedMethods methods = new GeneratedMethods(); + private static final Consumer methodSpecCustomizer = method -> {}; + + private final GeneratedMethods methods = new GeneratedMethods(MethodName::toString); @Test void createWhenMethodNameGeneratorIsNullThrowsException() { @@ -45,56 +47,82 @@ class GeneratedMethodsTests { @Test void createWithExistingGeneratorUsesGenerator() { - MethodNameGenerator generator = new MethodNameGenerator(); - generator.generateMethodName("test"); + Function generator = name -> "__" + name.toString(); GeneratedMethods methods = new GeneratedMethods(generator); - assertThat(methods.add("test").getName()).hasToString("test1"); + assertThat(methods.add("test", methodSpecCustomizer).getName()).hasToString("__test"); + } + + @Test + void addWithMethodNameWhenSuggestedMethodIsNullThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> + this.methods.add((MethodName) null, methodSpecCustomizer)) + .withMessage("'suggestedName' must not be null"); + } + + @Test + void addWithMethodNameWhenMethodIsNullThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> + this.methods.add(MethodName.of("test"), null)) + .withMessage("'method' must not be null"); + } + + @Test + void addWithStringNameWhenSuggestedMethodIsNullThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> + this.methods.add((String) null, methodSpecCustomizer)) + .withMessage("'suggestedName' must not be null"); + } + + @Test + void addWithStringNameWhenMethodIsNullThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> + this.methods.add("test", null)) + .withMessage("'method' must not be null"); } @Test void addAddsMethod() { - this.methods.add("spring", "beans").using(this::build); - this.methods.add("spring", "context").using(this::build); - assertThat( - this.methods.stream().map(GeneratedMethod::getName).map(Object::toString)) + this.methods.add("springBeans", methodSpecCustomizer); + this.methods.add("springContext", methodSpecCustomizer); + assertThat(this.methods.stream().map(GeneratedMethod::getName).map(Object::toString)) .containsExactly("springBeans", "springContext"); } + @Test + void withPrefixWhenGeneratingGetMethodUsesPrefix() { + GeneratedMethod generateMethod = this.methods.withPrefix("myBean") + .add("getTest", methodSpecCustomizer); + assertThat(generateMethod.getName()).hasToString("getMyBeanTest"); + } + + @Test + void withPrefixWhenGeneratingSetMethodUsesPrefix() { + GeneratedMethod generateMethod = this.methods.withPrefix("myBean") + .add("setTest", methodSpecCustomizer); + assertThat(generateMethod.getName()).hasToString("setMyBeanTest"); + } + + @Test + void withPrefixWhenGeneratingIsMethodUsesPrefix() { + GeneratedMethod generateMethod = this.methods.withPrefix("myBean") + .add("isTest", methodSpecCustomizer); + assertThat(generateMethod.getName()).hasToString("isMyBeanTest"); + } + + @Test + void withPrefixWhenGeneratingOtherMethodUsesPrefix() { + GeneratedMethod generateMethod = this.methods.withPrefix("myBean") + .add("test", methodSpecCustomizer); + assertThat(generateMethod.getName()).hasToString("myBeanTest"); + } + @Test void doWithMethodSpecsAcceptsMethodSpecs() { - this.methods.add("spring", "beans").using(this::build); - this.methods.add("spring", "context").using(this::build); + this.methods.add("springBeans", methodSpecCustomizer); + this.methods.add("springContext", methodSpecCustomizer); List names = new ArrayList<>(); - this.methods.doWithMethodSpecs(spec -> names.add(spec.name)); + this.methods.doWithMethodSpecs(methodSpec -> names.add(methodSpec.name)); assertThat(names).containsExactly("springBeans", "springContext"); } - @Test - void doWithMethodSpecsWhenMethodHasNotHadSpecDefinedThrowsException() { - this.methods.add("spring"); - assertThatIllegalStateException() - .isThrownBy(() -> this.methods.doWithMethodSpecs(spec -> { - })).withMessage("Method 'spring' has no method spec defined"); - } - - @Test - void iteratorIteratesMethods() { - this.methods.add("spring", "beans").using(this::build); - this.methods.add("spring", "context").using(this::build); - Iterator iterator = this.methods.iterator(); - assertThat(iterator.next().getName()).hasToString("springBeans"); - assertThat(iterator.next().getName()).hasToString("springContext"); - assertThat(iterator.hasNext()).isFalse(); - } - - @Test - void streamStreamsMethods() { - this.methods.add("spring", "beans").using(this::build); - this.methods.add("spring", "context").using(this::build); - assertThat(this.methods.stream()).hasSize(2); - } - - private void build(MethodSpec.Builder builder) { - } - } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/MethodGeneratorWithNameTests.java b/spring-core/src/test/java/org/springframework/aot/generate/MethodGeneratorWithNameTests.java deleted file mode 100644 index 0b4fcaf065b..00000000000 --- a/spring-core/src/test/java/org/springframework/aot/generate/MethodGeneratorWithNameTests.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Tests {@link MethodGeneratorWithName}. - * - * @author Phillip Webb - * @since 6.0 - */ -class MethodGeneratorWithNameTests { - - private final GeneratedMethods generatedMethods = new GeneratedMethods(); - - @Test - void withNameWhenGeneratingGetMethod() { - GeneratedMethod generateMethod = generatedMethods.withName("my", "bean") - .generateMethod("get", "test"); - assertThat(generateMethod.getName()).hasToString("getMyBeanTest"); - } - - @Test - void withNameWhenGeneratingSetMethod() { - GeneratedMethod generateMethod = generatedMethods.withName("my", "bean") - .generateMethod("set", "test"); - assertThat(generateMethod.getName()).hasToString("setMyBeanTest"); - } - - @Test - void withNameWhenGeneratingIsMethod() { - GeneratedMethod generateMethod = generatedMethods.withName("my", "bean") - .generateMethod("is", "test"); - assertThat(generateMethod.getName()).hasToString("isMyBeanTest"); - } - - @Test - void withNameWhenGeneratingOtherMethod() { - GeneratedMethod generateMethod = generatedMethods.withName("my", "bean") - .generateMethod("test"); - assertThat(generateMethod.getName()).hasToString("myBeanTest"); - } - -} diff --git a/spring-core/src/test/java/org/springframework/aot/generate/MethodNameGeneratorTests.java b/spring-core/src/test/java/org/springframework/aot/generate/MethodNameGeneratorTests.java deleted file mode 100644 index 997853b4bd6..00000000000 --- a/spring-core/src/test/java/org/springframework/aot/generate/MethodNameGeneratorTests.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import java.io.InputStream; - -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Tests for {@link MethodNameGenerator}. - * - * @author Phillip Webb - */ -class MethodNameGeneratorTests { - - private final MethodNameGenerator generator = new MethodNameGenerator(); - - @Test - void createWithReservedNamesReservesNames() { - MethodNameGenerator generator = new MethodNameGenerator("testName"); - assertThat(generator.generateMethodName("test", "name")).hasToString("testName1"); - } - - @Test - void generateMethodNameGeneratesName() { - String generated = this.generator.generateMethodName("register", "myBean", - "bean"); - assertThat(generated).isEqualTo("registerMyBeanBean"); - } - - @Test - void generateMethodNameWhenHasNonLettersGeneratesName() { - String generated = this.generator.generateMethodName("register", "myBean123", - "bean"); - assertThat(generated).isEqualTo("registerMyBeanBean"); - } - - @Test - void generateMethodNameWhenHasDotsGeneratesCamelCaseName() { - String generated = this.generator.generateMethodName("register", - "org.springframework.example.bean"); - assertThat(generated).isEqualTo("registerOrgSpringframeworkExampleBean"); - } - - @Test - void generateMethodNameWhenMultipleCallsGeneratesSequencedName() { - String generated1 = this.generator.generateMethodName("register", "myBean123", - "bean"); - String generated2 = this.generator.generateMethodName("register", "myBean!", - "bean"); - String generated3 = this.generator.generateMethodName("register", "myBean%%", - "bean"); - assertThat(generated1).isEqualTo("registerMyBeanBean"); - assertThat(generated2).isEqualTo("registerMyBeanBean1"); - assertThat(generated3).isEqualTo("registerMyBeanBean2"); - } - - @Test - void generateMethodNameWhenAllEmptyPartsGeneratesSetName() { - String generated = this.generator.generateMethodName("123"); - assertThat(generated).isEqualTo("$$aot"); - } - - @Test - void joinReturnsJoinedName() { - assertThat(MethodNameGenerator.join("get", "bean", "factory")) - .isEqualTo("getBeanFactory"); - assertThat(MethodNameGenerator.join("get", null, "factory")) - .isEqualTo("getFactory"); - assertThat(MethodNameGenerator.join(null, null)).isEqualTo(""); - assertThat(MethodNameGenerator.join("", null)).isEqualTo(""); - assertThat(MethodNameGenerator.join("get", InputStream.class)) - .isEqualTo("getInputStream"); - - } - -} diff --git a/spring-core/src/test/java/org/springframework/aot/generate/MethodNameTests.java b/spring-core/src/test/java/org/springframework/aot/generate/MethodNameTests.java new file mode 100644 index 00000000000..4267275bfa8 --- /dev/null +++ b/spring-core/src/test/java/org/springframework/aot/generate/MethodNameTests.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link MethodName}. + * + * @author Phillip Webb + */ +class MethodNameTests { + + @Test + void ofWhenPartsIsNullThrowsException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> MethodName.of((String[]) null)) + .withMessage("'parts' must not be null"); + } + + @Test + void ofReturnsMethodName() { + assertThat(MethodName.of("get", "bean", "factory")).hasToString("getBeanFactory"); + assertThat(MethodName.of("get", null, "factory")).hasToString("getFactory"); + assertThat(MethodName.of(null, null)).hasToString("$$aot"); + assertThat(MethodName.of("", null)).hasToString("$$aot"); + assertThat(MethodName.of("get", "InputStream")).hasToString("getInputStream"); + assertThat(MethodName.of("register", "myBean123", "bean")).hasToString("registerMyBeanBean"); + assertThat(MethodName.of("register", "org.springframework.example.bean")) + .hasToString("registerOrgSpringframeworkExampleBean"); + } + + @Test + void andPartsWhenPartsIsNullThrowsException() { + MethodName name = MethodName.of("myBean"); + assertThatIllegalArgumentException() + .isThrownBy(() -> name.and(((String[]) null))) + .withMessage("'parts' must not be null"); + } + + @Test + void andPartsReturnsMethodName() { + MethodName name = MethodName.of("myBean"); + assertThat(name.and("test")).hasToString("myBeanTest"); + assertThat(name.and("test", null)).hasToString("myBeanTest"); + assertThat(name.and("getName")).hasToString("getMyBeanName"); + assertThat(name.and("setName")).hasToString("setMyBeanName"); + assertThat(name.and("isDoingOk")).hasToString("isMyBeanDoingOk"); + assertThat(name.and("this", "that", "the", "other")).hasToString("myBeanThisThatTheOther"); + } + + @Test + void andNameWhenPartsIsNullThrowsException() { + MethodName name = MethodName.of("myBean"); + assertThatIllegalArgumentException() + .isThrownBy(() -> name.and(((MethodName) null))) + .withMessage("'name' must not be null"); + } + + @Test + void andNameReturnsMethodName() { + MethodName name = MethodName.of("myBean"); + assertThat(name.and(MethodName.of("test"))).hasToString("myBeanTest"); + } + + @Test + void hashCodeAndEquals() { + MethodName name1 = MethodName.of("myBean"); + MethodName name2 = MethodName.of("my", "bean"); + MethodName name3 = MethodName.of("myOtherBean"); + assertThat(name1.hashCode()).isEqualTo(name2.hashCode()); + assertThat(name1).isEqualTo(name1).isEqualTo(name2).isNotEqualTo(name3); + } + +} diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java index 3363a6ec018..b5d08004b6c 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java @@ -30,7 +30,6 @@ import java.util.Map; import java.util.Properties; import java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Consumer; import jakarta.persistence.EntityManager; import jakarta.persistence.EntityManagerFactory; @@ -42,8 +41,9 @@ import jakarta.persistence.SynchronizationType; import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.MethodName; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.BeanUtils; @@ -72,7 +72,6 @@ import org.springframework.core.PriorityOrdered; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; -import org.springframework.javapoet.MethodSpec.Builder; import org.springframework.jndi.JndiLocatorDelegate; import org.springframework.jndi.JndiTemplate; import org.springframework.lang.Nullable; @@ -766,8 +765,6 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar private static class AotContribution implements BeanRegistrationAotContribution { - private static final String APPLY_METHOD = "apply"; - private static final String REGISTERED_BEAN_PARAMETER = "registeredBean"; private static final String INSTANCE_PARAMETER = "instance"; @@ -788,45 +785,38 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { GeneratedClass generatedClass = generationContext.getGeneratedClasses() - .forFeatureComponent("PersistenceInjection", this.target).generate(type -> { + .addForFeatureComponent("PersistenceInjection", this.target, type -> { type.addJavadoc("Persistence injection for {@link $T}.", this.target); type.addModifiers(javax.lang.model.element.Modifier.PUBLIC); }); - generatedClass.getMethodGenerator().generateMethod(APPLY_METHOD) - .using(generateMethod(generationContext.getRuntimeHints(), generatedClass.getMethodGenerator())); - beanRegistrationCode.addInstancePostProcessor( - MethodReference.ofStatic(generatedClass.getName(), APPLY_METHOD)); - } - - private Consumer generateMethod(RuntimeHints hints, MethodGenerator methodGenerator) { - return method -> { + GeneratedMethod generatedMethod = generatedClass.getMethods().add("apply", method -> { method.addJavadoc("Apply the persistence injection."); method.addModifiers(javax.lang.model.element.Modifier.PUBLIC, javax.lang.model.element.Modifier.STATIC); method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); method.addParameter(this.target, INSTANCE_PARAMETER); method.returns(this.target); - method.addCode(generateMethodCode(hints, methodGenerator)); - }; + method.addCode(generateMethodCode(generationContext.getRuntimeHints(), generatedClass.getMethods())); + }); + beanRegistrationCode.addInstancePostProcessor(MethodReference + .ofStatic(generatedClass.getName(), generatedMethod.getName())); } - private CodeBlock generateMethodCode(RuntimeHints hints, - MethodGenerator methodGenerator) { - CodeBlock.Builder builder = CodeBlock.builder(); - InjectionCodeGenerator injectionCodeGenerator = new InjectionCodeGenerator( - hints); + private CodeBlock generateMethodCode(RuntimeHints hints, GeneratedMethods generatedMethods) { + CodeBlock.Builder code = CodeBlock.builder(); + InjectionCodeGenerator injectionCodeGenerator = new InjectionCodeGenerator(hints); for (InjectedElement injectedElement : this.injectedElements) { - CodeBlock resourceToInject = getResourceToInject(methodGenerator, + CodeBlock resourceToInject = generateResourceToInjectCode(generatedMethods, (PersistenceElement) injectedElement); - builder.add(injectionCodeGenerator.generateInjectionCode( + code.add(injectionCodeGenerator.generateInjectionCode( injectedElement.getMember(), INSTANCE_PARAMETER, resourceToInject)); } - builder.addStatement("return $L", INSTANCE_PARAMETER); - return builder.build(); + code.addStatement("return $L", INSTANCE_PARAMETER); + return code.build(); } - private CodeBlock getResourceToInject(MethodGenerator methodGenerator, + private CodeBlock generateResourceToInjectCode(GeneratedMethods generatedMethods, PersistenceElement injectedElement) { String unitName = injectedElement.unitName; boolean requireEntityManager = (injectedElement.type != null); @@ -836,40 +826,38 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar EntityManagerFactoryUtils.class, ListableBeanFactory.class, REGISTERED_BEAN_PARAMETER, unitName); } - GeneratedMethod getEntityManagerMethod = methodGenerator - .generateMethod("get", unitName, "EntityManager") - .using(builder -> buildGetEntityManagerMethod(builder, - injectedElement)); - return CodeBlock.of("$L($L)", getEntityManagerMethod.getName(), - REGISTERED_BEAN_PARAMETER); + GeneratedMethod generatedMethod = generatedMethods + .add(MethodName.of("get", unitName, "EntityManager"), method -> + generateGetEntityManagerMethod(method, injectedElement)); + return CodeBlock.of("$L($L)", generatedMethod.getName(), REGISTERED_BEAN_PARAMETER); } - private void buildGetEntityManagerMethod(MethodSpec.Builder builder, + private void generateGetEntityManagerMethod(MethodSpec.Builder method, PersistenceElement injectedElement) { String unitName = injectedElement.unitName; Properties properties = injectedElement.properties; - builder.addJavadoc("Get the '$L' {@link $T}", + method.addJavadoc("Get the '$L' {@link $T}", (StringUtils.hasLength(unitName)) ? unitName : "default", EntityManager.class); - builder.addModifiers(javax.lang.model.element.Modifier.PUBLIC, + method.addModifiers(javax.lang.model.element.Modifier.PUBLIC, javax.lang.model.element.Modifier.STATIC); - builder.returns(EntityManager.class); - builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); - builder.addStatement( + method.returns(EntityManager.class); + method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); + method.addStatement( "$T entityManagerFactory = $T.findEntityManagerFactory(($T) $L.getBeanFactory(), $S)", EntityManagerFactory.class, EntityManagerFactoryUtils.class, ListableBeanFactory.class, REGISTERED_BEAN_PARAMETER, unitName); boolean hasProperties = !CollectionUtils.isEmpty(properties); if (hasProperties) { - builder.addStatement("$T properties = new Properties()", + method.addStatement("$T properties = new Properties()", Properties.class); for (String propertyName : new TreeSet<>( properties.stringPropertyNames())) { - builder.addStatement("properties.put($S, $S)", propertyName, + method.addStatement("properties.put($S, $S)", propertyName, properties.getProperty(propertyName)); } } - builder.addStatement( + method.addStatement( "return $T.createSharedEntityManager(entityManagerFactory, $L, $L)", SharedEntityManagerCreator.class, (hasProperties) ? "properties" : null,