diff --git a/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java b/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java index 36d44764b15..5ad52df235b 100644 --- a/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java +++ b/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java @@ -163,8 +163,8 @@ class ScopedProxyBeanRegistrationAotProcessor implements BeanRegistrationAotProc method.addStatement("return ($T) factory.getObject()", beanClass); }); - return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class, - beanRegistrationCode.getClassName(), generatedMethod.getName()); + return CodeBlock.of("$T.of($L)", InstanceSupplier.class, + generatedMethod.toMethodReference().toCodeBlock()); } } diff --git a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java index e29a3192492..1ce8b798680 100644 --- a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java +++ b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java @@ -26,6 +26,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aop.framework.AopInfrastructureBean; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.Compiled; import org.springframework.aot.test.generate.compile.TestCompiler; @@ -139,11 +140,14 @@ class ScopedProxyBeanRegistrationAotProcessorTests { MethodReference methodReference = this.beanFactoryInitializationCode .getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java index 2154610858f..ace0f093ac6 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java @@ -44,7 +44,6 @@ import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.ExecutableMode; import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.BeanUtils; @@ -944,8 +943,7 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA method.returns(this.target); method.addCode(generateMethodCode(generationContext.getRuntimeHints())); }); - beanRegistrationCode.addInstancePostProcessor( - MethodReference.ofStatic(generatedClass.getName(), generateMethod.getName())); + beanRegistrationCode.addInstancePostProcessor(generateMethod.toMethodReference()); if (this.candidateResolver != null) { registerHints(generationContext.getRuntimeHints()); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java index 8f08985b792..29703c3eeb5 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -107,16 +107,14 @@ class BeanDefinitionMethodGenerator { GeneratedMethod generatedMethod = generateBeanDefinitionMethod( generationContext, generatedClass.getName(), generatedMethods, codeFragments, Modifier.PUBLIC); - return MethodReference.ofStatic(generatedClass.getName(), - generatedMethod.getName()); + return generatedMethod.toMethodReference(); } GeneratedMethods generatedMethods = beanRegistrationsCode.getMethods() .withPrefix(getName()); GeneratedMethod generatedMethod = generateBeanDefinitionMethod(generationContext, beanRegistrationsCode.getClassName(), generatedMethods, codeFragments, Modifier.PRIVATE); - return MethodReference.ofStatic(beanRegistrationsCode.getClassName(), - generatedMethod.getName()); + return generatedMethod.toMethodReference(); } private BeanRegistrationCodeFragments getCodeFragments(GenerationContext generationContext, diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java index 654fdda5866..e7da3299e81 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java @@ -24,6 +24,7 @@ import org.springframework.aot.generate.MethodReference; * perform bean factory initialization. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see BeanFactoryInitializationAotContribution */ @@ -41,10 +42,16 @@ public interface BeanFactoryInitializationCode { GeneratedMethods getMethods(); /** - * Add an initializer method call. - * @param methodReference a reference to the initialize method to call. The - * referenced method must have the same functional signature as - * {@code Consumer}. + * Add an initializer method call. An initializer can use a flexible signature, + * using any of the following: + * + * @param methodReference a reference to the initialize method to call. */ void addInitializer(MethodReference methodReference); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java index 2dad04b877f..a80db112d35 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; @@ -65,8 +66,7 @@ class BeanRegistrationsAotContribution BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator(generatedClass); GeneratedMethod generatedMethod = codeGenerator.getMethods().add("registerBeanDefinitions", method -> generateRegisterMethod(method, generationContext, codeGenerator)); - beanFactoryInitializationCode.addInitializer( - MethodReference.of(generatedClass.getName(), generatedMethod.getName())); + beanFactoryInitializationCode.addInitializer(generatedMethod.toMethodReference()); } private void generateRegisterMethod(MethodSpec.Builder method, @@ -82,9 +82,11 @@ class BeanRegistrationsAotContribution MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator .generateBeanDefinitionMethod(generationContext, beanRegistrationsCode); + CodeBlock methodInvocation = beanDefinitionMethod.toInvokeCodeBlock( + ArgumentCodeGenerator.none(), beanRegistrationsCode.getClassName()); code.addStatement("$L.registerBeanDefinition($S, $L)", BEAN_FACTORY_PARAMETER_NAME, beanName, - beanDefinitionMethod.toInvokeCodeBlock()); + methodInvocation); }); method.addCode(code.build()); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index bf5719eccdb..fa04f862115 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -24,6 +24,7 @@ import java.util.function.Predicate; import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; @@ -156,7 +157,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments MethodReference generatedMethod = methodGenerator .generateBeanDefinitionMethod(generationContext, this.beanRegistrationsCode); - return generatedMethod.toInvokeCodeBlock(); + return generatedMethod.toInvokeCodeBlock(ArgumentCodeGenerator.none()); } return null; } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index 66bba162047..e6cb5df84f1 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -28,6 +28,7 @@ import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.ExecutableMode; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; @@ -296,8 +297,9 @@ class InstanceSupplierCodeGenerator { REGISTERED_BEAN_PARAMETER_NAME, declaringClass, factoryMethodName, args); } - private CodeBlock generateReturnStatement(GeneratedMethod getInstanceMethod) { - return CodeBlock.of("$T.$L()", this.className, getInstanceMethod.getName()); + private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) { + return generatedMethod.toMethodReference().toInvokeCodeBlock( + ArgumentCodeGenerator.none(), this.className); } private CodeBlock generateWithGeneratorCode(boolean hasArguments, CodeBlock newInstance) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java index 2f4fe187b13..219093424e7 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java @@ -24,6 +24,7 @@ import javax.lang.model.element.Modifier; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess; @@ -161,13 +162,16 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { Class target = registeredBean.getBeanClass(); MethodReference methodReference = this.beanRegistrationCode.getInstancePostProcessors().get(0); this.beanRegistrationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(RegisteredBean.class, "registeredBean").and(target, "instance"), + this.beanRegistrationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(BiFunction.class, RegisteredBean.class, target, target)); type.addMethod(MethodSpec.methodBuilder("apply") .addModifiers(Modifier.PUBLIC) .addParameter(RegisteredBean.class, "registeredBean") .addParameter(target, "instance").returns(target) - .addStatement("return $L", methodReference.toInvokeCodeBlock(CodeBlock.of("registeredBean"), CodeBlock.of("instance"))) + .addStatement("return $L", methodInvocation) .build()); }); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 53c182ddc58..9020bf8626a 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess; import org.springframework.aot.test.generate.compile.Compiled; @@ -129,8 +130,7 @@ class BeanDefinitionMethodGeneratorTests { .addParameter(RegisteredBean.class, "registeredBean") .addParameter(TestBean.class, "testBean") .returns(TestBean.class).addCode("return new $T($S);", TestBean.class, "postprocessed")); - beanRegistrationCode.addInstancePostProcessor(MethodReference.ofStatic( - beanRegistrationCode.getClassName(), generatedMethod.getName())); + beanRegistrationCode.addInstancePostProcessor(generatedMethod.toMethodReference()); }; List aotContributions = Collections .singletonList(aotContribution); @@ -167,8 +167,7 @@ class BeanDefinitionMethodGeneratorTests { .addParameter(RegisteredBean.class, "registeredBean") .addParameter(TestBean.class, "testBean") .returns(TestBean.class).addCode("return new $T($S);", TestBean.class, "postprocessed")); - beanRegistrationCode.addInstancePostProcessor(MethodReference.ofStatic( - beanRegistrationCode.getClassName(), generatedMethod.getName())); + beanRegistrationCode.addInstancePostProcessor(generatedMethod.toMethodReference()); }; List aotContributions = Collections .singletonList(aotContribution); @@ -416,12 +415,14 @@ class BeanDefinitionMethodGeneratorTests { private void compile(MethodReference method, BiConsumer result) { this.beanRegistrationsCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = method.toInvokeCodeBlock(ArgumentCodeGenerator.none(), + this.beanRegistrationsCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, BeanDefinition.class)); type.addMethod(MethodSpec.methodBuilder("get") .addModifiers(Modifier.PUBLIC) .returns(BeanDefinition.class) - .addCode("return $L;", method.toInvokeCodeBlock()).build()); + .addCode("return $L;", methodInvocation).build()); }); this.generationContext.writeGeneratedContent(); TestCompiler.forSystem().withFiles(this.generationContext.getGeneratedFiles()).compile(compiled -> diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java index 1eee3a83434..bd2bba145e9 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.generate.ClassNameGenerator; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.TestTarget; import org.springframework.aot.test.generate.compile.Compiled; @@ -155,11 +156,14 @@ class BeanRegistrationsAotContributionTests { MethodReference methodReference = this.beanFactoryInitializationCode .getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java index 01a78dda3b4..c6986c7c4b0 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; +import org.springframework.javapoet.ClassName; /** * Mock {@link BeanFactoryInitializationCode} implementation. @@ -46,6 +47,9 @@ public class MockBeanFactoryInitializationCode implements BeanFactoryInitializat .addForFeature("TestCode", this.typeBuilder); } + public ClassName getClassName() { + return this.generatedClass.getName(); + } public DeferredTypeBuilder getTypeBuilder() { return this.typeBuilder; diff --git a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java index 3c2fa738179..02db1a37c9c 100644 --- a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java +++ b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java @@ -33,7 +33,6 @@ import org.apache.commons.logging.LogFactory; import org.springframework.aop.framework.autoproxy.AutoProxyUtils; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.ResourceHints; import org.springframework.aot.hint.TypeReference; import org.springframework.beans.PropertyValues; @@ -536,7 +535,7 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo .add("addImportAwareBeanPostProcessors", method -> generateAddPostProcessorMethod(method, mappings)); beanFactoryInitializationCode - .addInitializer(MethodReference.of(generatedMethod.getName())); + .addInitializer(generatedMethod.toMethodReference()); ResourceHints hints = generationContext.getRuntimeHints().resources(); mappings.forEach( (target, from) -> hints.registerType(TypeReference.of(from))); diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java index 29f502c7353..1be6885e9f8 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java @@ -18,6 +18,7 @@ package org.springframework.context.aot; import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import javax.lang.model.element.Modifier; @@ -25,16 +26,24 @@ import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver; import org.springframework.context.support.GenericApplicationContext; import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.Environment; +import org.springframework.core.io.ResourceLoader; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.javapoet.TypeName; import org.springframework.javapoet.TypeSpec; +import org.springframework.lang.Nullable; /** * Internal code generator to create the {@link ApplicationContextInitializer}. @@ -88,12 +97,17 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class); code.addStatement("$L.setDependencyComparator($T.INSTANCE)", BEAN_FACTORY_VARIABLE, AnnotationAwareOrderComparator.class); + ArgumentCodeGenerator argCodeGenerator = createInitializerMethodArgumentCodeGenerator(); for (MethodReference initializer : this.initializers) { - code.addStatement(initializer.toInvokeCodeBlock(CodeBlock.of(BEAN_FACTORY_VARIABLE))); + code.addStatement(initializer.toInvokeCodeBlock(argCodeGenerator, this.generatedClass.getName())); } return code.build(); } + static ArgumentCodeGenerator createInitializerMethodArgumentCodeGenerator() { + return ArgumentCodeGenerator.from(new InitializerMethodArgumentCodeGenerator()); + } + GeneratedClass getGeneratedClass() { return this.generatedClass; } @@ -108,4 +122,30 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia this.initializers.add(methodReference); } + private static class InitializerMethodArgumentCodeGenerator implements Function { + + @Override + @Nullable + public CodeBlock apply(TypeName typeName) { + return (typeName instanceof ClassName className ? apply(className) : null); + } + + @Nullable + private CodeBlock apply(ClassName className) { + String name = className.canonicalName(); + if (name.equals(DefaultListableBeanFactory.class.getName()) + || name.equals(ConfigurableListableBeanFactory.class.getName())) { + return CodeBlock.of(BEAN_FACTORY_VARIABLE); + } + else if (name.equals(ConfigurableEnvironment.class.getName()) + || name.equals(Environment.class.getName())) { + return CodeBlock.of("$L.getConfigurableEnvironment()", APPLICATION_CONTEXT_VARIABLE); + } + else if (name.equals(ResourceLoader.class.getName())) { + return CodeBlock.of(APPLICATION_CONTEXT_VARIABLE); + } + return null; + } + } + } diff --git a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java index 24bef4dcf15..29961c61070 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java @@ -27,6 +27,7 @@ import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.ResourcePatternHint; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.Compiled; @@ -162,11 +163,14 @@ class ConfigurationClassPostProcessorAotContributionTests { private void compile(BiConsumer, Compiled> result) { MethodReference methodReference = this.beanFactoryInitializationCode.getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java new file mode 100644 index 00000000000..1155be09528 --- /dev/null +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextInitializationCodeGeneratorTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.aot; + +import java.util.stream.Stream; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.AbstractBeanFactory; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.Environment; +import org.springframework.core.env.StandardEnvironment; +import org.springframework.core.io.ResourceLoader; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link ApplicationContextInitializationCodeGenerator}. + * + * @author Stephane Nicoll + */ +class ApplicationContextInitializationCodeGeneratorTests { + + private static final ArgumentCodeGenerator argCodeGenerator = ApplicationContextInitializationCodeGenerator. + createInitializerMethodArgumentCodeGenerator(); + + @ParameterizedTest + @MethodSource("methodArguments") + void argumentsForSupportedTypesAreResolved(Class target, String expectedArgument) { + CodeBlock code = CodeBlock.of(expectedArgument); + assertThat(argCodeGenerator.generateCode(ClassName.get(target))).isEqualTo(code); + } + + @Test + void argumentForUnsupportedBeanFactoryIsNotResolved() { + assertThat(argCodeGenerator.generateCode(ClassName.get(AbstractBeanFactory.class))).isNull(); + } + + @Test + void argumentForUnsupportedEnvironmentIsNotResolved() { + assertThat(argCodeGenerator.generateCode(ClassName.get(StandardEnvironment.class))).isNull(); + } + + static Stream methodArguments() { + String applicationContext = "applicationContext"; + String environment = applicationContext + ".getConfigurableEnvironment()"; + return Stream.of( + Arguments.of(DefaultListableBeanFactory.class, "beanFactory"), + Arguments.of(ConfigurableListableBeanFactory.class, "beanFactory"), + Arguments.of(ConfigurableEnvironment.class, environment), + Arguments.of(Environment.class, environment), + Arguments.of(ResourceLoader.class, applicationContext)); + } + +} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java b/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java new file mode 100644 index 00000000000..b3a3ab117d8 --- /dev/null +++ b/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java @@ -0,0 +1,134 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import java.util.ArrayList; +import java.util.List; + +import javax.lang.model.element.Modifier; + +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.TypeName; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Default {@link MethodReference} implementation based on a {@link MethodSpec}. + * + * @author Stephane Nicoll + * @author Phillip Webb + * @since 6.0 + */ +public class DefaultMethodReference implements MethodReference { + + private final MethodSpec method; + + @Nullable + private final ClassName declaringClass; + + public DefaultMethodReference(MethodSpec method, @Nullable ClassName declaringClass) { + this.method = method; + this.declaringClass = declaringClass; + } + + @Override + public CodeBlock toCodeBlock() { + String methodName = this.method.name; + if (isStatic()) { + Assert.notNull(this.declaringClass, "static method reference must define a declaring class"); + return CodeBlock.of("$T::$L", this.declaringClass, methodName); + } + else { + return CodeBlock.of("this::$L", methodName); + } + } + + public CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator, + @Nullable ClassName targetClassName) { + String methodName = this.method.name; + CodeBlock.Builder code = CodeBlock.builder(); + if (isStatic()) { + Assert.notNull(this.declaringClass, "static method reference must define a declaring class"); + if (isSameDeclaringClass(targetClassName)) { + code.add("$L", methodName); + } + else { + code.add("$T.$L", this.declaringClass, methodName); + } + } + else { + if (!isSameDeclaringClass(targetClassName)) { + code.add(instantiateDeclaringClass(this.declaringClass)); + } + code.add("$L", methodName); + } + code.add("("); + addArguments(code, argumentCodeGenerator); + code.add(")"); + return code.build(); + } + + /** + * Add the code for the method arguments using the specified + * {@link ArgumentCodeGenerator} if necessary. + * @param code the code builder to use to add method arguments + * @param argumentCodeGenerator the code generator to use + */ + protected void addArguments(CodeBlock.Builder code, ArgumentCodeGenerator argumentCodeGenerator) { + List arguments = new ArrayList<>(); + TypeName[] argumentTypes = this.method.parameters.stream() + .map(parameter -> parameter.type).toArray(TypeName[]::new); + for (int i = 0; i < argumentTypes.length; i++) { + TypeName argumentType = argumentTypes[i]; + CodeBlock argumentCode = argumentCodeGenerator.generateCode(argumentType); + if (argumentCode == null) { + throw new IllegalArgumentException("Could not generate code for " + this + + ": parameter " + i + " of type " + argumentType + " is not supported"); + } + arguments.add(argumentCode); + } + code.add(CodeBlock.join(arguments, ", ")); + } + + protected CodeBlock instantiateDeclaringClass(ClassName declaringClass) { + return CodeBlock.of("new $T().", declaringClass); + } + + private boolean isStatic() { + return this.method.modifiers.contains(Modifier.STATIC); + } + + private boolean isSameDeclaringClass(ClassName declaringClass) { + return this.declaringClass == null || this.declaringClass.equals(declaringClass); + } + + @Override + public String toString() { + String methodName = this.method.name; + if (isStatic()) { + return this.declaringClass + "::" + methodName; + } + else { + return ((this.declaringClass != null) + ? "<" + this.declaringClass + ">" : "") + + "::" + methodName; + } + } + +} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java index be591208fff..5a1bb302187 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java @@ -55,7 +55,7 @@ public final class GeneratedClass { GeneratedClass(ClassName name, Consumer type) { this.name = name; this.type = type; - this.methods = new GeneratedMethods(this::generateSequencedMethodName); + this.methods = new GeneratedMethods(name, this::generateSequencedMethodName); } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java index 4d351241d0d..b09d36f61f2 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java @@ -18,6 +18,7 @@ package org.springframework.aot.generate; import java.util.function.Consumer; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.MethodSpec; import org.springframework.util.Assert; @@ -25,11 +26,14 @@ import org.springframework.util.Assert; * A generated method. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see GeneratedMethods */ public final class GeneratedMethod { + private final ClassName className; + private final String name; private final MethodSpec methodSpec; @@ -39,12 +43,14 @@ public final class GeneratedMethod { * Create a new {@link GeneratedMethod} instance with the given name. This * constructor is package-private since names should only be generated via * {@link GeneratedMethods}. + * @param className the declaring class of the method * @param name the generated method name * @param method consumer to generate the method */ - GeneratedMethod(String name, Consumer method) { + GeneratedMethod(ClassName className, String name, Consumer method) { + this.className = className; this.name = name; - MethodSpec.Builder builder = MethodSpec.methodBuilder(getName()); + MethodSpec.Builder builder = MethodSpec.methodBuilder(this.name); method.accept(builder); this.methodSpec = builder.build(); Assert.state(this.name.equals(this.methodSpec.name), @@ -60,6 +66,14 @@ public final class GeneratedMethod { return this.name; } + /** + * Return a {@link MethodReference} to this generated method. + * @return a method reference + */ + public MethodReference toMethodReference() { + return new DefaultMethodReference(this.methodSpec, this.className); + } + /** * Return the {@link MethodSpec} for this generated method. * @return the method spec diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java index e16779ab779..0c65c37582a 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethods.java @@ -22,6 +22,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Stream; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.MethodSpec.Builder; import org.springframework.util.Assert; @@ -30,11 +31,14 @@ import org.springframework.util.Assert; * A managed collection of generated methods. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see GeneratedMethod */ public class GeneratedMethods { + private final ClassName className; + private final Function methodNameGenerator; private final MethodName prefix; @@ -44,18 +48,22 @@ public class GeneratedMethods { /** * Create a new {@link GeneratedMethods} using the specified method name * generator. + * @param className the declaring class name * @param methodNameGenerator the method name generator */ - GeneratedMethods(Function methodNameGenerator) { + GeneratedMethods(ClassName className, Function methodNameGenerator) { + Assert.notNull(className, "'className' must not be null"); Assert.notNull(methodNameGenerator, "'methodNameGenerator' must not be null"); + this.className = className; this.methodNameGenerator = methodNameGenerator; this.prefix = MethodName.NONE; this.generatedMethods = new ArrayList<>(); } - private GeneratedMethods(Function methodNameGenerator, + private GeneratedMethods(ClassName className, Function methodNameGenerator, MethodName prefix, List generatedMethods) { + this.className = className; this.methodNameGenerator = methodNameGenerator; this.prefix = prefix; this.generatedMethods = generatedMethods; @@ -82,7 +90,7 @@ public class GeneratedMethods { Assert.notNull(suggestedNameParts, "'suggestedNameParts' must not be null"); Assert.notNull(method, "'method' must not be null"); String generatedName = this.methodNameGenerator.apply(this.prefix.and(suggestedNameParts)); - GeneratedMethod generatedMethod = new GeneratedMethod(generatedName, method); + GeneratedMethod generatedMethod = new GeneratedMethod(this.className, generatedName, method); this.generatedMethods.add(generatedMethod); return generatedMethod; } @@ -90,7 +98,8 @@ public class GeneratedMethods { public GeneratedMethods withPrefix(String prefix) { Assert.notNull(prefix, "'prefix' must not be null"); - return new GeneratedMethods(this.methodNameGenerator, this.prefix.and(prefix), this.generatedMethods); + return new GeneratedMethods(this.className, this.methodNameGenerator, + this.prefix.and(prefix), this.generatedMethods); } /** diff --git a/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java b/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java index 80359dd314b..f6dda971007 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java @@ -16,223 +16,124 @@ package org.springframework.aot.generate; +import java.util.function.Function; + import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.TypeName; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; /** - * A reference to a static or instance method. + * A reference to a method with convenient code generation for + * referencing, or invoking it. * + * @author Stephane Nicoll * @author Phillip Webb * @since 6.0 */ -public final class MethodReference { - - private final Kind kind; - - @Nullable - private final ClassName declaringClass; - - private final String methodName; - - - private MethodReference(Kind kind, @Nullable ClassName declaringClass, - String methodName) { - this.kind = kind; - this.declaringClass = declaringClass; - this.methodName = methodName; - } - - - /** - * Create a new method reference that refers to the given instance method. - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(String methodName) { - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, null, methodName); - } - - /** - * Create a new method reference that refers to the given instance method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(Class declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, ClassName.get(declaringClass), - methodName); - } - - /** - * Create a new method reference that refers to the given instance method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(ClassName declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, declaringClass, methodName); - } - - /** - * Create a new method reference that refers to the given static method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference ofStatic(Class declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.STATIC, ClassName.get(declaringClass), - methodName); - } - - /** - * Create a new method reference that refers to the given static method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference ofStatic(ClassName declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.STATIC, declaringClass, methodName); - } - - - /** - * Return the referenced declaring class. - * @return the declaring class - */ - @Nullable - public ClassName getDeclaringClass() { - return this.declaringClass; - } - - /** - * Return the referenced method name. - * @return the method name - */ - public String getMethodName() { - return this.methodName; - } +public interface MethodReference { /** * Return this method reference as a {@link CodeBlock}. If the reference is * to an instance method then {@code this::} will be returned. * @return a code block for the method reference. - * @see #toCodeBlock(String) */ - public CodeBlock toCodeBlock() { - return toCodeBlock(null); + CodeBlock toCodeBlock(); + + /** + * Return this method reference as a {@link CodeBlock} using the specified + * {@link ArgumentCodeGenerator}. + * @param argumentCodeGenerator the argument code generator to use + * @return a code block to invoke the method + */ + default CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator) { + return toInvokeCodeBlock(argumentCodeGenerator, null); } /** - * Return this method reference as a {@link CodeBlock}. If the reference is - * to an instance method and {@code instanceVariable} is {@code null} then - * {@code this::} will be returned. No {@code instanceVariable} - * can be specified for static method references. - * @param instanceVariable the instance variable or {@code null} - * @return a code block for the method reference. - * @see #toCodeBlock(String) + * Return this method reference as a {@link CodeBlock} using the specified + * {@link ArgumentCodeGenerator}. The {@code targetClassName} defines the + * context in which the method invocation is added. + *

If the caller has an instance of the type in which this method is + * defined, it can hint that by specifying the type as a target class. + * @param argumentCodeGenerator the argument code generator to use + * @param targetClassName the target class name + * @return a code block to invoke the method */ - public CodeBlock toCodeBlock(@Nullable String instanceVariable) { - return switch (this.kind) { - case INSTANCE -> toCodeBlockForInstance(instanceVariable); - case STATIC -> toCodeBlockForStatic(instanceVariable); - }; - } + CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator, @Nullable ClassName targetClassName); - private CodeBlock toCodeBlockForInstance(@Nullable String instanceVariable) { - instanceVariable = (instanceVariable != null) ? instanceVariable : "this"; - return CodeBlock.of("$L::$L", instanceVariable, this.methodName); - } - - private CodeBlock toCodeBlockForStatic(@Nullable String instanceVariable) { - Assert.isTrue(instanceVariable == null, - "'instanceVariable' must be null for static method references"); - return CodeBlock.of("$T::$L", this.declaringClass, this.methodName); - } /** - * Return this method reference as an invocation {@link CodeBlock}. - * @param arguments the method arguments - * @return a code back to invoke the method + * Strategy for generating code for arguments based on their type. */ - public CodeBlock toInvokeCodeBlock(CodeBlock... arguments) { - return toInvokeCodeBlock(null, arguments); - } + interface ArgumentCodeGenerator { - /** - * Return this method reference as an invocation {@link CodeBlock}. - * @param instanceVariable the instance variable or {@code null} - * @param arguments the method arguments - * @return a code back to invoke the method - */ - public CodeBlock toInvokeCodeBlock(@Nullable String instanceVariable, - CodeBlock... arguments) { + /** + * Generate the code for the given argument type. If this type is + * not supported, return {@code null}. + * @param argumentType the argument type + * @return the code for this argument, or {@code null} + */ + @Nullable + CodeBlock generateCode(TypeName argumentType); - return switch (this.kind) { - case INSTANCE -> toInvokeCodeBlockForInstance(instanceVariable, arguments); - case STATIC -> toInvokeCodeBlockForStatic(instanceVariable, arguments); - }; - } - - private CodeBlock toInvokeCodeBlockForInstance(@Nullable String instanceVariable, - CodeBlock[] arguments) { - - CodeBlock.Builder code = CodeBlock.builder(); - if (instanceVariable != null) { - code.add("$L.", instanceVariable); + /** + * Factory method that returns an {@link ArgumentCodeGenerator} that + * always returns {@code null}. + * @return a new {@link ArgumentCodeGenerator} instance + */ + static ArgumentCodeGenerator none() { + return from(type -> null); } - else if (this.declaringClass != null) { - code.add("new $T().", this.declaringClass); + + /** + * Factory method that can be used to create an {@link ArgumentCodeGenerator} + * that support only the given argument type. + * @param argumentType the argument type + * @param argumentCode the code for an argument of that type + * @return a new {@link ArgumentCodeGenerator} instance + */ + static ArgumentCodeGenerator of(Class argumentType, String argumentCode) { + return from(candidateType -> (candidateType.equals(ClassName.get(argumentType)) + ? CodeBlock.of(argumentCode) : null)); } - code.add("$L", this.methodName); - addArguments(code, arguments); - return code.build(); - } - private CodeBlock toInvokeCodeBlockForStatic(@Nullable String instanceVariable, - CodeBlock[] arguments) { - - Assert.isTrue(instanceVariable == null, - "'instanceVariable' must be null for static method references"); - CodeBlock.Builder code = CodeBlock.builder(); - code.add("$T.$L", this.declaringClass, this.methodName); - addArguments(code, arguments); - return code.build(); - } - - private void addArguments(CodeBlock.Builder code, CodeBlock[] arguments) { - code.add("("); - for (int i = 0; i < arguments.length; i++) { - if (i != 0) { - code.add(", "); - } - code.add(arguments[i]); + /** + * Factory method that creates a new {@link ArgumentCodeGenerator} from + * a lambda friendly function. The given function is provided with the + * argument type and must provide the code to use or {@code null} if + * the type is not supported. + * @param function the resolver function + * @return a new {@link ArgumentCodeGenerator} instance backed by the function + */ + static ArgumentCodeGenerator from(Function function) { + return function::apply; } - code.add(")"); - } - @Override - public String toString() { - return switch (this.kind) { - case INSTANCE -> ((this.declaringClass != null) ? "<" + this.declaringClass + ">" - : "") + "::" + this.methodName; - case STATIC -> this.declaringClass + "::" + this.methodName; - }; - } + /** + * Create a new composed {@link ArgumentCodeGenerator} by combining this + * generator with supporting the given argument type. + * @param argumentType the argument type + * @param argumentCode the code for an argument of that type + * @return a new composite {@link ArgumentCodeGenerator} instance + */ + default ArgumentCodeGenerator and(Class argumentType, String argumentCode) { + return and(ArgumentCodeGenerator.of(argumentType, argumentCode)); + } + /** + * Create a new composed {@link ArgumentCodeGenerator} by combining this + * generator with the given generator. + * @param argumentCodeGenerator the argument generator to add + * @return a new composite {@link ArgumentCodeGenerator} instance + */ + default ArgumentCodeGenerator and(ArgumentCodeGenerator argumentCodeGenerator) { + return from(type -> { + CodeBlock code = generateCode(type); + return (code != null ? code : argumentCodeGenerator.generateCode(type)); + }); + } - private enum Kind { - INSTANCE, STATIC } } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java b/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java new file mode 100644 index 00000000000..b9643151ca9 --- /dev/null +++ b/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java @@ -0,0 +1,199 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.MethodSpec.Builder; +import org.springframework.javapoet.TypeName; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link DefaultMethodReference}. + * + * @author Phillip Webb + * @author Stephane Nicoll + */ +class DefaultMethodReferenceTests { + + private static final String EXPECTED_STATIC = "org.springframework.aot.generate.DefaultMethodReferenceTests::someMethod"; + + private static final String EXPECTED_ANONYMOUS_INSTANCE = "::someMethod"; + + private static final String EXPECTED_DECLARED_INSTANCE = "::someMethod"; + + private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test"); + + private static final ClassName INITIALIZER_CLASS_NAME = ClassName.get("com.example", "Initializer"); + + @Test + void createWithStringCreatesMethodReference() { + MethodSpec method = createTestMethod("someMethod", new TypeName[0]); + MethodReference reference = new DefaultMethodReference(method, null); + assertThat(reference).hasToString(EXPECTED_ANONYMOUS_INSTANCE); + } + + @Test + void createWithClassNameAndStringCreateMethodReference() { + ClassName declaringClass = ClassName.get(DefaultMethodReferenceTests.class); + MethodReference reference = createMethodReference("someMethod", new TypeName[0], declaringClass); + assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); + } + + @Test + void createWithStaticAndClassAndStringCreatesMethodReference() { + ClassName declaringClass = ClassName.get(DefaultMethodReferenceTests.class); + MethodReference reference = createStaticMethodReference("someMethod", declaringClass); + assertThat(reference).hasToString(EXPECTED_STATIC); + } + + @Test + void toCodeBlock() { + assertThat(createLocalMethodReference("methodName").toCodeBlock()) + .isEqualTo(CodeBlock.of("this::methodName")); + } + + @Test + void toCodeBlockWithStaticMethod() { + assertThat(createStaticMethodReference("methodName", TEST_CLASS_NAME).toCodeBlock()) + .isEqualTo(CodeBlock.of("com.example.Test::methodName")); + } + + @Test + void toCodeBlockWithStaticMethodRequiresDeclaringClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0], Modifier.STATIC); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThatIllegalArgumentException().isThrownBy(methodReference::toCodeBlock) + .withMessage("static method reference must define a declaring class"); + } + + @Test + void toInvokeCodeBlockWithNullDeclaringClassAndTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithNullDeclaringClassAndNullTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none())) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithDeclaringClassAndNullTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none())) + .isEqualTo(CodeBlock.of("new com.example.Test().methodName()")); + } + + @Test + void toInvokeCodeBlockWithMatchingTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + CodeBlock invocation = methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME); + // Assume com.example.Test is in a `test` variable. + assertThat(CodeBlock.of("$L.$L", "test", invocation)).isEqualTo(CodeBlock.of("test.methodName()")); + } + + @Test + void toInvokeCodeBlockWithNonMatchingDeclaringClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), INITIALIZER_CLASS_NAME)) + .isEqualTo(CodeBlock.of("new com.example.Test().methodName()")); + } + + @Test + void toInvokeCodeBlockWithMatchingArg() { + MethodReference methodReference = createLocalMethodReference("methodName", ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(String.class, "stringArg"); + assertThat(methodReference.toInvokeCodeBlock(argCodeGenerator)) + .isEqualTo(CodeBlock.of("methodName(stringArg)")); + } + + @Test + void toInvokeCodeBlockWithMatchingArgs() { + MethodReference methodReference = createLocalMethodReference("methodName", + ClassName.get(Integer.class), ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(String.class, "stringArg") + .and(Integer.class, "integerArg"); + assertThat(methodReference.toInvokeCodeBlock(argCodeGenerator)) + .isEqualTo(CodeBlock.of("methodName(integerArg, stringArg)")); + } + + @Test + void toInvokeCodeBlockWithNonMatchingArg() { + MethodReference methodReference = createLocalMethodReference("methodName", + ClassName.get(Integer.class), ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(Integer.class, "integerArg"); + assertThatIllegalArgumentException().isThrownBy(() -> methodReference.toInvokeCodeBlock(argCodeGenerator)) + .withMessageContaining("parameter 1 of type java.lang.String is not supported"); + } + + @Test + void toInvokeCodeBlockWithStaticMethodAndMatchingDeclaringClass() { + MethodReference methodReference = createStaticMethodReference("methodName", TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithStaticMethodAndSeparateDeclaringClass() { + MethodReference methodReference = createStaticMethodReference("methodName", TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), INITIALIZER_CLASS_NAME)) + .isEqualTo(CodeBlock.of("com.example.Test.methodName()")); + } + + + private MethodReference createLocalMethodReference(String name, TypeName... argumentTypes) { + return createMethodReference(name, argumentTypes, null); + } + + private MethodReference createMethodReference(String name, TypeName[] argumentTypes, @Nullable ClassName declaringClass) { + MethodSpec method = createTestMethod(name, argumentTypes); + return new DefaultMethodReference(method, declaringClass); + } + + private MethodReference createStaticMethodReference(String name, ClassName declaringClass, TypeName... argumentTypes) { + MethodSpec method = createTestMethod(name, argumentTypes, Modifier.STATIC); + return new DefaultMethodReference(method, declaringClass); + } + + private MethodSpec createTestMethod(String name, TypeName[] argumentTypes, Modifier... modifiers) { + Builder method = MethodSpec.methodBuilder(name); + for (int i = 0; i < argumentTypes.length; i++) { + method.addParameter(argumentTypes[i], "args" + i); + } + method.addModifiers(modifiers); + return method.build(); + } + +} diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java index f080ce8ed20..6e865bd4eb8 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java @@ -18,8 +18,13 @@ package org.springframework.aot.generate; import java.util.function.Consumer; +import javax.lang.model.element.Modifier; + import org.junit.jupiter.api.Test; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; import static org.assertj.core.api.Assertions.assertThat; @@ -29,30 +34,56 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException; * Tests for {@link GeneratedMethod}. * * @author Phillip Webb + * @author Stephane Nicoll */ class GeneratedMethodTests { - private static final Consumer methodSpecCustomizer = method -> {}; + private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test"); + + private static final Consumer emptyMethod = method -> {}; private static final String NAME = "spring"; @Test void getNameReturnsName() { - GeneratedMethod generatedMethod = new GeneratedMethod(NAME, methodSpecCustomizer); + GeneratedMethod generatedMethod = new GeneratedMethod(TEST_CLASS_NAME, NAME, emptyMethod); assertThat(generatedMethod.getName()).isSameAs(NAME); } @Test void generateMethodSpecReturnsMethodSpec() { - GeneratedMethod generatedMethod = new GeneratedMethod(NAME, method -> method.addJavadoc("Test")); + GeneratedMethod generatedMethod = create(method -> method.addJavadoc("Test")); assertThat(generatedMethod.getMethodSpec().javadoc).asString().contains("Test"); } @Test void generateMethodSpecWhenMethodNameIsChangedThrowsException() { assertThatIllegalStateException().isThrownBy(() -> - new GeneratedMethod(NAME, method -> method.setName("badname")).getMethodSpec()) - .withMessage("'method' consumer must not change the generated method name"); + create(method -> method.setName("badname")).getMethodSpec()) + .withMessage("'method' consumer must not change the generated method name"); + } + + @Test + void toMethodReferenceWithInstanceMethod() { + GeneratedMethod generatedMethod = create(emptyMethod); + MethodReference methodReference = generatedMethod.toMethodReference(); + assertThat(methodReference).isNotNull(); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("spring()")); + } + + @Test + void toMethodReferenceWithStaticMethod() { + GeneratedMethod generatedMethod = create(method -> method.addModifiers(Modifier.STATIC)); + MethodReference methodReference = generatedMethod.toMethodReference(); + assertThat(methodReference).isNotNull(); + ClassName anotherDeclaringClass = ClassName.get("com.example", "Another"); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), anotherDeclaringClass)) + .isEqualTo(CodeBlock.of("com.example.Test.spring()")); + } + + private GeneratedMethod create(Consumer method) { + return new GeneratedMethod(TEST_CLASS_NAME, NAME, method); } } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java index 2ae2517e6c8..1b051f691e0 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodsTests.java @@ -23,6 +23,7 @@ import java.util.function.Function; import org.junit.jupiter.api.Test; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.MethodSpec; import static org.assertj.core.api.Assertions.assertThat; @@ -32,38 +33,49 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException * Tests for {@link GeneratedMethods}. * * @author Phillip Webb + * @author Stephane Nicoll */ class GeneratedMethodsTests { + private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test"); + private static final Consumer methodSpecCustomizer = method -> {}; - private final GeneratedMethods methods = new GeneratedMethods(MethodName::toString); + private final GeneratedMethods methods = new GeneratedMethods(TEST_CLASS_NAME, MethodName::toString); + + @Test + void createWhenClassNameIsNullThrowsException() { + assertThatIllegalArgumentException().isThrownBy(() -> + new GeneratedMethods(null, MethodName::toString)) + .withMessage("'className' must not be null"); + } @Test void createWhenMethodNameGeneratorIsNullThrowsException() { - assertThatIllegalArgumentException().isThrownBy(() -> new GeneratedMethods(null)) + assertThatIllegalArgumentException().isThrownBy(() -> + new GeneratedMethods(TEST_CLASS_NAME, null)) .withMessage("'methodNameGenerator' must not be null"); } @Test void createWithExistingGeneratorUsesGenerator() { Function generator = name -> "__" + name.toString(); - GeneratedMethods methods = new GeneratedMethods(generator); + GeneratedMethods methods = new GeneratedMethods(TEST_CLASS_NAME, generator); assertThat(methods.add("test", methodSpecCustomizer).getName()).hasToString("__test"); } @Test void addWithStringNameWhenSuggestedMethodIsNullThrowsException() { assertThatIllegalArgumentException().isThrownBy(() -> - this.methods.add((String) null, methodSpecCustomizer)) - .withMessage("'suggestedName' must not be null"); + this.methods.add((String) null, methodSpecCustomizer)) + .withMessage("'suggestedName' must not be null"); } @Test void addWithStringNameWhenMethodIsNullThrowsException() { assertThatIllegalArgumentException().isThrownBy(() -> - this.methods.add("test", null)) - .withMessage("'method' must not be null"); + this.methods.add("test", null)) + .withMessage("'method' must not be null"); } @Test @@ -71,7 +83,7 @@ class GeneratedMethodsTests { this.methods.add("springBeans", methodSpecCustomizer); this.methods.add("springContext", methodSpecCustomizer); assertThat(this.methods.stream().map(GeneratedMethod::getName).map(Object::toString)) - .containsExactly("springBeans", "springContext"); + .containsExactly("springBeans", "springContext"); } @Test diff --git a/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java b/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java deleted file mode 100644 index de5c79667b4..00000000000 --- a/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import org.junit.jupiter.api.Test; - -import org.springframework.javapoet.ClassName; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; - -/** - * Tests for {@link MethodReference}. - * - * @author Phillip Webb - */ -class MethodReferenceTests { - - private static final String EXPECTED_STATIC = "org.springframework.aot.generate.MethodReferenceTests::someMethod"; - - private static final String EXPECTED_ANONYMOUS_INSTANCE = "::someMethod"; - - private static final String EXPECTED_DECLARED_INSTANCE = "::someMethod"; - - - @Test - void ofWithStringWhenMethodNameIsNullThrowsException() { - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithStringCreatesMethodReference() { - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(methodName); - assertThat(reference).hasToString(EXPECTED_ANONYMOUS_INSTANCE); - } - - @Test - void ofWithClassAndStringWhenDeclaringClassIsNullThrowsException() { - Class declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofWithClassAndStringWhenMethodNameIsNullThrowsException() { - Class declaringClass = MethodReferenceTests.class; - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithClassAndStringCreatesMethodReference() { - Class declaringClass = MethodReferenceTests.class; - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); - } - - @Test - void ofWithClassNameAndStringWhenDeclaringClassIsNullThrowsException() { - ClassName declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofWithClassNameAndStringWhenMethodNameIsNullThrowsException() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithClassNameAndStringCreateMethodReference() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); - } - - @Test - void ofStaticWithClassAndStringWhenDeclaringClassIsNullThrowsException() { - Class declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofStaticWithClassAndStringWhenMethodNameIsEmptyThrowsException() { - Class declaringClass = MethodReferenceTests.class; - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofStaticWithClassAndStringCreatesMethodReference() { - Class declaringClass = MethodReferenceTests.class; - String methodName = "someMethod"; - MethodReference reference = MethodReference.ofStatic(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_STATIC); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameWhenDeclaringClassIsNullThrowsException() { - ClassName declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameWhenMethodNameIsEmptyThrowsException() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameCreatesMethodReference() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = "someMethod"; - MethodReference reference = MethodReference.ofStatic(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_STATIC); - } - - @Test - void toCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toCodeBlock(null)).hasToString("this::someMethod"); - } - - @Test - void toCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNotNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toCodeBlock("myInstance")) - .hasToString("myInstance::someMethod"); - } - - @Test - void toCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toCodeBlock(null)).hasToString(EXPECTED_STATIC); - } - - @Test - void toCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNotNullThrowsException() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThatIllegalArgumentException() - .isThrownBy(() -> reference.toCodeBlock("myInstance")).withMessage( - "'instanceVariable' must be null for static method references"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString("someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNullAndHasDecalredClass() { - MethodReference reference = MethodReference.of(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString( - "new org.springframework.aot.generate.MethodReferenceTests().someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNotNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toInvokeCodeBlock("myInstance")) - .hasToString("myInstance.someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString( - "org.springframework.aot.generate.MethodReferenceTests.someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNotNullThrowsException() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThatIllegalArgumentException() - .isThrownBy(() -> reference.toInvokeCodeBlock("myInstance")).withMessage( - "'instanceVariable' must be null for static method references"); - } - -} diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java index cf88dc8acbc..e8d8fe3848b 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java @@ -99,7 +99,7 @@ class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistr List.class, toCodeBlock(persistenceManagedTypes.getManagedPackages())); method.addStatement("return $T.of($L, $L)", beanType, "managedClassNames", "managedPackages"); }); - return CodeBlock.of("() -> $T.$L()", beanRegistrationCode.getClassName(), generatedMethod.getName()); + return generatedMethod.toMethodReference().toCodeBlock(); } private CodeBlock toCodeBlock(List values) { diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java index 47078596607..abac448ac02 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java @@ -43,7 +43,6 @@ import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.BeanUtils; import org.springframework.beans.PropertyValues; @@ -797,8 +796,7 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar method.returns(this.target); method.addCode(generateMethodCode(generationContext.getRuntimeHints(), generatedClass.getMethods())); }); - beanRegistrationCode.addInstancePostProcessor(MethodReference - .ofStatic(generatedClass.getName(), generatedMethod.getName())); + beanRegistrationCode.addInstancePostProcessor(generatedMethod.toMethodReference()); } private CodeBlock generateMethodCode(RuntimeHints hints, GeneratedMethods generatedMethods) {