diff --git a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java index e29a3192492..1ce8b798680 100644 --- a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java +++ b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java @@ -26,6 +26,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aop.framework.AopInfrastructureBean; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.Compiled; import org.springframework.aot.test.generate.compile.TestCompiler; @@ -139,11 +140,14 @@ class ScopedProxyBeanRegistrationAotProcessorTests { MethodReference methodReference = this.beanFactoryInitializationCode .getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java index fc8ca237b8b..a80db112d35 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; @@ -81,9 +82,11 @@ class BeanRegistrationsAotContribution MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator .generateBeanDefinitionMethod(generationContext, beanRegistrationsCode); + CodeBlock methodInvocation = beanDefinitionMethod.toInvokeCodeBlock( + ArgumentCodeGenerator.none(), beanRegistrationsCode.getClassName()); code.addStatement("$L.registerBeanDefinition($S, $L)", BEAN_FACTORY_PARAMETER_NAME, beanName, - beanDefinitionMethod.toInvokeCodeBlock()); + methodInvocation); }); method.addCode(code.build()); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index bf5719eccdb..fa04f862115 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -24,6 +24,7 @@ import java.util.function.Predicate; import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; @@ -156,7 +157,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments MethodReference generatedMethod = methodGenerator .generateBeanDefinitionMethod(generationContext, this.beanRegistrationsCode); - return generatedMethod.toInvokeCodeBlock(); + return generatedMethod.toInvokeCodeBlock(ArgumentCodeGenerator.none()); } return null; } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index 1b597a36d37..e6cb5df84f1 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -28,6 +28,7 @@ import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.ExecutableMode; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; @@ -297,7 +298,8 @@ class InstanceSupplierCodeGenerator { } private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) { - return generatedMethod.toMethodReference().toInvokeCodeBlock(); + return generatedMethod.toMethodReference().toInvokeCodeBlock( + ArgumentCodeGenerator.none(), this.className); } private CodeBlock generateWithGeneratorCode(boolean hasArguments, CodeBlock newInstance) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java index 2f4fe187b13..219093424e7 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java @@ -24,6 +24,7 @@ import javax.lang.model.element.Modifier; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess; @@ -161,13 +162,16 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { Class target = registeredBean.getBeanClass(); MethodReference methodReference = this.beanRegistrationCode.getInstancePostProcessors().get(0); this.beanRegistrationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(RegisteredBean.class, "registeredBean").and(target, "instance"), + this.beanRegistrationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(BiFunction.class, RegisteredBean.class, target, target)); type.addMethod(MethodSpec.methodBuilder("apply") .addModifiers(Modifier.PUBLIC) .addParameter(RegisteredBean.class, "registeredBean") .addParameter(target, "instance").returns(target) - .addStatement("return $L", methodReference.toInvokeCodeBlock(CodeBlock.of("registeredBean"), CodeBlock.of("instance"))) + .addStatement("return $L", methodInvocation) .build()); }); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 3cc278470a1..9020bf8626a 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess; import org.springframework.aot.test.generate.compile.Compiled; @@ -414,12 +415,14 @@ class BeanDefinitionMethodGeneratorTests { private void compile(MethodReference method, BiConsumer result) { this.beanRegistrationsCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = method.toInvokeCodeBlock(ArgumentCodeGenerator.none(), + this.beanRegistrationsCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, BeanDefinition.class)); type.addMethod(MethodSpec.methodBuilder("get") .addModifiers(Modifier.PUBLIC) .returns(BeanDefinition.class) - .addCode("return $L;", method.toInvokeCodeBlock()).build()); + .addCode("return $L;", methodInvocation).build()); }); this.generationContext.writeGeneratedContent(); TestCompiler.forSystem().withFiles(this.generationContext.getGeneratedFiles()).compile(compiled -> diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java index 1eee3a83434..bd2bba145e9 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.generate.ClassNameGenerator; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.TestTarget; import org.springframework.aot.test.generate.compile.Compiled; @@ -155,11 +156,14 @@ class BeanRegistrationsAotContributionTests { MethodReference methodReference = this.beanFactoryInitializationCode .getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java index 01a78dda3b4..c6986c7c4b0 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; +import org.springframework.javapoet.ClassName; /** * Mock {@link BeanFactoryInitializationCode} implementation. @@ -46,6 +47,9 @@ public class MockBeanFactoryInitializationCode implements BeanFactoryInitializat .addForFeature("TestCode", this.typeBuilder); } + public ClassName getClassName() { + return this.generatedClass.getName(); + } public DeferredTypeBuilder getTypeBuilder() { return this.typeBuilder; diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java index 29f502c7353..b2bf870e2f5 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationContextInitializer; @@ -88,12 +89,17 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class); code.addStatement("$L.setDependencyComparator($T.INSTANCE)", BEAN_FACTORY_VARIABLE, AnnotationAwareOrderComparator.class); + ArgumentCodeGenerator argCodeGenerator = createInitializerMethodsArgumentCodeGenerator(); for (MethodReference initializer : this.initializers) { - code.addStatement(initializer.toInvokeCodeBlock(CodeBlock.of(BEAN_FACTORY_VARIABLE))); + code.addStatement(initializer.toInvokeCodeBlock(argCodeGenerator, this.generatedClass.getName())); } return code.build(); } + private ArgumentCodeGenerator createInitializerMethodsArgumentCodeGenerator() { + return ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, BEAN_FACTORY_VARIABLE); + } + GeneratedClass getGeneratedClass() { return this.generatedClass; } diff --git a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java index 24bef4dcf15..29961c61070 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java @@ -27,6 +27,7 @@ import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.ResourcePatternHint; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.Compiled; @@ -162,11 +163,14 @@ class ConfigurationClassPostProcessorAotContributionTests { private void compile(BiConsumer, Compiled> result) { MethodReference methodReference = this.beanFactoryInitializationCode.getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java b/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java new file mode 100644 index 00000000000..b3a3ab117d8 --- /dev/null +++ b/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java @@ -0,0 +1,134 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import java.util.ArrayList; +import java.util.List; + +import javax.lang.model.element.Modifier; + +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.TypeName; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Default {@link MethodReference} implementation based on a {@link MethodSpec}. + * + * @author Stephane Nicoll + * @author Phillip Webb + * @since 6.0 + */ +public class DefaultMethodReference implements MethodReference { + + private final MethodSpec method; + + @Nullable + private final ClassName declaringClass; + + public DefaultMethodReference(MethodSpec method, @Nullable ClassName declaringClass) { + this.method = method; + this.declaringClass = declaringClass; + } + + @Override + public CodeBlock toCodeBlock() { + String methodName = this.method.name; + if (isStatic()) { + Assert.notNull(this.declaringClass, "static method reference must define a declaring class"); + return CodeBlock.of("$T::$L", this.declaringClass, methodName); + } + else { + return CodeBlock.of("this::$L", methodName); + } + } + + public CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator, + @Nullable ClassName targetClassName) { + String methodName = this.method.name; + CodeBlock.Builder code = CodeBlock.builder(); + if (isStatic()) { + Assert.notNull(this.declaringClass, "static method reference must define a declaring class"); + if (isSameDeclaringClass(targetClassName)) { + code.add("$L", methodName); + } + else { + code.add("$T.$L", this.declaringClass, methodName); + } + } + else { + if (!isSameDeclaringClass(targetClassName)) { + code.add(instantiateDeclaringClass(this.declaringClass)); + } + code.add("$L", methodName); + } + code.add("("); + addArguments(code, argumentCodeGenerator); + code.add(")"); + return code.build(); + } + + /** + * Add the code for the method arguments using the specified + * {@link ArgumentCodeGenerator} if necessary. + * @param code the code builder to use to add method arguments + * @param argumentCodeGenerator the code generator to use + */ + protected void addArguments(CodeBlock.Builder code, ArgumentCodeGenerator argumentCodeGenerator) { + List arguments = new ArrayList<>(); + TypeName[] argumentTypes = this.method.parameters.stream() + .map(parameter -> parameter.type).toArray(TypeName[]::new); + for (int i = 0; i < argumentTypes.length; i++) { + TypeName argumentType = argumentTypes[i]; + CodeBlock argumentCode = argumentCodeGenerator.generateCode(argumentType); + if (argumentCode == null) { + throw new IllegalArgumentException("Could not generate code for " + this + + ": parameter " + i + " of type " + argumentType + " is not supported"); + } + arguments.add(argumentCode); + } + code.add(CodeBlock.join(arguments, ", ")); + } + + protected CodeBlock instantiateDeclaringClass(ClassName declaringClass) { + return CodeBlock.of("new $T().", declaringClass); + } + + private boolean isStatic() { + return this.method.modifiers.contains(Modifier.STATIC); + } + + private boolean isSameDeclaringClass(ClassName declaringClass) { + return this.declaringClass == null || this.declaringClass.equals(declaringClass); + } + + @Override + public String toString() { + String methodName = this.method.name; + if (isStatic()) { + return this.declaringClass + "::" + methodName; + } + else { + return ((this.declaringClass != null) + ? "<" + this.declaringClass + ">" : "") + + "::" + methodName; + } + } + +} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java index 7247c212d0c..b09d36f61f2 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java @@ -18,8 +18,6 @@ package org.springframework.aot.generate; import java.util.function.Consumer; -import javax.lang.model.element.Modifier; - import org.springframework.javapoet.ClassName; import org.springframework.javapoet.MethodSpec; import org.springframework.util.Assert; @@ -73,9 +71,7 @@ public final class GeneratedMethod { * @return a method reference */ public MethodReference toMethodReference() { - return (this.methodSpec.modifiers.contains(Modifier.STATIC) - ? MethodReference.ofStatic(this.className, this.name) - : MethodReference.of(this.className, this.name)); + return new DefaultMethodReference(this.methodSpec, this.className); } /** diff --git a/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java b/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java index 80359dd314b..f6dda971007 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java @@ -16,223 +16,124 @@ package org.springframework.aot.generate; +import java.util.function.Function; + import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.TypeName; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; /** - * A reference to a static or instance method. + * A reference to a method with convenient code generation for + * referencing, or invoking it. * + * @author Stephane Nicoll * @author Phillip Webb * @since 6.0 */ -public final class MethodReference { - - private final Kind kind; - - @Nullable - private final ClassName declaringClass; - - private final String methodName; - - - private MethodReference(Kind kind, @Nullable ClassName declaringClass, - String methodName) { - this.kind = kind; - this.declaringClass = declaringClass; - this.methodName = methodName; - } - - - /** - * Create a new method reference that refers to the given instance method. - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(String methodName) { - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, null, methodName); - } - - /** - * Create a new method reference that refers to the given instance method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(Class declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, ClassName.get(declaringClass), - methodName); - } - - /** - * Create a new method reference that refers to the given instance method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(ClassName declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, declaringClass, methodName); - } - - /** - * Create a new method reference that refers to the given static method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference ofStatic(Class declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.STATIC, ClassName.get(declaringClass), - methodName); - } - - /** - * Create a new method reference that refers to the given static method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference ofStatic(ClassName declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.STATIC, declaringClass, methodName); - } - - - /** - * Return the referenced declaring class. - * @return the declaring class - */ - @Nullable - public ClassName getDeclaringClass() { - return this.declaringClass; - } - - /** - * Return the referenced method name. - * @return the method name - */ - public String getMethodName() { - return this.methodName; - } +public interface MethodReference { /** * Return this method reference as a {@link CodeBlock}. If the reference is * to an instance method then {@code this::} will be returned. * @return a code block for the method reference. - * @see #toCodeBlock(String) */ - public CodeBlock toCodeBlock() { - return toCodeBlock(null); + CodeBlock toCodeBlock(); + + /** + * Return this method reference as a {@link CodeBlock} using the specified + * {@link ArgumentCodeGenerator}. + * @param argumentCodeGenerator the argument code generator to use + * @return a code block to invoke the method + */ + default CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator) { + return toInvokeCodeBlock(argumentCodeGenerator, null); } /** - * Return this method reference as a {@link CodeBlock}. If the reference is - * to an instance method and {@code instanceVariable} is {@code null} then - * {@code this::} will be returned. No {@code instanceVariable} - * can be specified for static method references. - * @param instanceVariable the instance variable or {@code null} - * @return a code block for the method reference. - * @see #toCodeBlock(String) + * Return this method reference as a {@link CodeBlock} using the specified + * {@link ArgumentCodeGenerator}. The {@code targetClassName} defines the + * context in which the method invocation is added. + *

If the caller has an instance of the type in which this method is + * defined, it can hint that by specifying the type as a target class. + * @param argumentCodeGenerator the argument code generator to use + * @param targetClassName the target class name + * @return a code block to invoke the method */ - public CodeBlock toCodeBlock(@Nullable String instanceVariable) { - return switch (this.kind) { - case INSTANCE -> toCodeBlockForInstance(instanceVariable); - case STATIC -> toCodeBlockForStatic(instanceVariable); - }; - } + CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator, @Nullable ClassName targetClassName); - private CodeBlock toCodeBlockForInstance(@Nullable String instanceVariable) { - instanceVariable = (instanceVariable != null) ? instanceVariable : "this"; - return CodeBlock.of("$L::$L", instanceVariable, this.methodName); - } - - private CodeBlock toCodeBlockForStatic(@Nullable String instanceVariable) { - Assert.isTrue(instanceVariable == null, - "'instanceVariable' must be null for static method references"); - return CodeBlock.of("$T::$L", this.declaringClass, this.methodName); - } /** - * Return this method reference as an invocation {@link CodeBlock}. - * @param arguments the method arguments - * @return a code back to invoke the method + * Strategy for generating code for arguments based on their type. */ - public CodeBlock toInvokeCodeBlock(CodeBlock... arguments) { - return toInvokeCodeBlock(null, arguments); - } + interface ArgumentCodeGenerator { - /** - * Return this method reference as an invocation {@link CodeBlock}. - * @param instanceVariable the instance variable or {@code null} - * @param arguments the method arguments - * @return a code back to invoke the method - */ - public CodeBlock toInvokeCodeBlock(@Nullable String instanceVariable, - CodeBlock... arguments) { + /** + * Generate the code for the given argument type. If this type is + * not supported, return {@code null}. + * @param argumentType the argument type + * @return the code for this argument, or {@code null} + */ + @Nullable + CodeBlock generateCode(TypeName argumentType); - return switch (this.kind) { - case INSTANCE -> toInvokeCodeBlockForInstance(instanceVariable, arguments); - case STATIC -> toInvokeCodeBlockForStatic(instanceVariable, arguments); - }; - } - - private CodeBlock toInvokeCodeBlockForInstance(@Nullable String instanceVariable, - CodeBlock[] arguments) { - - CodeBlock.Builder code = CodeBlock.builder(); - if (instanceVariable != null) { - code.add("$L.", instanceVariable); + /** + * Factory method that returns an {@link ArgumentCodeGenerator} that + * always returns {@code null}. + * @return a new {@link ArgumentCodeGenerator} instance + */ + static ArgumentCodeGenerator none() { + return from(type -> null); } - else if (this.declaringClass != null) { - code.add("new $T().", this.declaringClass); + + /** + * Factory method that can be used to create an {@link ArgumentCodeGenerator} + * that support only the given argument type. + * @param argumentType the argument type + * @param argumentCode the code for an argument of that type + * @return a new {@link ArgumentCodeGenerator} instance + */ + static ArgumentCodeGenerator of(Class argumentType, String argumentCode) { + return from(candidateType -> (candidateType.equals(ClassName.get(argumentType)) + ? CodeBlock.of(argumentCode) : null)); } - code.add("$L", this.methodName); - addArguments(code, arguments); - return code.build(); - } - private CodeBlock toInvokeCodeBlockForStatic(@Nullable String instanceVariable, - CodeBlock[] arguments) { - - Assert.isTrue(instanceVariable == null, - "'instanceVariable' must be null for static method references"); - CodeBlock.Builder code = CodeBlock.builder(); - code.add("$T.$L", this.declaringClass, this.methodName); - addArguments(code, arguments); - return code.build(); - } - - private void addArguments(CodeBlock.Builder code, CodeBlock[] arguments) { - code.add("("); - for (int i = 0; i < arguments.length; i++) { - if (i != 0) { - code.add(", "); - } - code.add(arguments[i]); + /** + * Factory method that creates a new {@link ArgumentCodeGenerator} from + * a lambda friendly function. The given function is provided with the + * argument type and must provide the code to use or {@code null} if + * the type is not supported. + * @param function the resolver function + * @return a new {@link ArgumentCodeGenerator} instance backed by the function + */ + static ArgumentCodeGenerator from(Function function) { + return function::apply; } - code.add(")"); - } - @Override - public String toString() { - return switch (this.kind) { - case INSTANCE -> ((this.declaringClass != null) ? "<" + this.declaringClass + ">" - : "") + "::" + this.methodName; - case STATIC -> this.declaringClass + "::" + this.methodName; - }; - } + /** + * Create a new composed {@link ArgumentCodeGenerator} by combining this + * generator with supporting the given argument type. + * @param argumentType the argument type + * @param argumentCode the code for an argument of that type + * @return a new composite {@link ArgumentCodeGenerator} instance + */ + default ArgumentCodeGenerator and(Class argumentType, String argumentCode) { + return and(ArgumentCodeGenerator.of(argumentType, argumentCode)); + } + /** + * Create a new composed {@link ArgumentCodeGenerator} by combining this + * generator with the given generator. + * @param argumentCodeGenerator the argument generator to add + * @return a new composite {@link ArgumentCodeGenerator} instance + */ + default ArgumentCodeGenerator and(ArgumentCodeGenerator argumentCodeGenerator) { + return from(type -> { + CodeBlock code = generateCode(type); + return (code != null ? code : argumentCodeGenerator.generateCode(type)); + }); + } - private enum Kind { - INSTANCE, STATIC } } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java b/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java new file mode 100644 index 00000000000..b9643151ca9 --- /dev/null +++ b/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java @@ -0,0 +1,199 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.MethodSpec.Builder; +import org.springframework.javapoet.TypeName; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link DefaultMethodReference}. + * + * @author Phillip Webb + * @author Stephane Nicoll + */ +class DefaultMethodReferenceTests { + + private static final String EXPECTED_STATIC = "org.springframework.aot.generate.DefaultMethodReferenceTests::someMethod"; + + private static final String EXPECTED_ANONYMOUS_INSTANCE = "::someMethod"; + + private static final String EXPECTED_DECLARED_INSTANCE = "::someMethod"; + + private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test"); + + private static final ClassName INITIALIZER_CLASS_NAME = ClassName.get("com.example", "Initializer"); + + @Test + void createWithStringCreatesMethodReference() { + MethodSpec method = createTestMethod("someMethod", new TypeName[0]); + MethodReference reference = new DefaultMethodReference(method, null); + assertThat(reference).hasToString(EXPECTED_ANONYMOUS_INSTANCE); + } + + @Test + void createWithClassNameAndStringCreateMethodReference() { + ClassName declaringClass = ClassName.get(DefaultMethodReferenceTests.class); + MethodReference reference = createMethodReference("someMethod", new TypeName[0], declaringClass); + assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); + } + + @Test + void createWithStaticAndClassAndStringCreatesMethodReference() { + ClassName declaringClass = ClassName.get(DefaultMethodReferenceTests.class); + MethodReference reference = createStaticMethodReference("someMethod", declaringClass); + assertThat(reference).hasToString(EXPECTED_STATIC); + } + + @Test + void toCodeBlock() { + assertThat(createLocalMethodReference("methodName").toCodeBlock()) + .isEqualTo(CodeBlock.of("this::methodName")); + } + + @Test + void toCodeBlockWithStaticMethod() { + assertThat(createStaticMethodReference("methodName", TEST_CLASS_NAME).toCodeBlock()) + .isEqualTo(CodeBlock.of("com.example.Test::methodName")); + } + + @Test + void toCodeBlockWithStaticMethodRequiresDeclaringClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0], Modifier.STATIC); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThatIllegalArgumentException().isThrownBy(methodReference::toCodeBlock) + .withMessage("static method reference must define a declaring class"); + } + + @Test + void toInvokeCodeBlockWithNullDeclaringClassAndTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithNullDeclaringClassAndNullTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none())) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithDeclaringClassAndNullTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none())) + .isEqualTo(CodeBlock.of("new com.example.Test().methodName()")); + } + + @Test + void toInvokeCodeBlockWithMatchingTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + CodeBlock invocation = methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME); + // Assume com.example.Test is in a `test` variable. + assertThat(CodeBlock.of("$L.$L", "test", invocation)).isEqualTo(CodeBlock.of("test.methodName()")); + } + + @Test + void toInvokeCodeBlockWithNonMatchingDeclaringClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), INITIALIZER_CLASS_NAME)) + .isEqualTo(CodeBlock.of("new com.example.Test().methodName()")); + } + + @Test + void toInvokeCodeBlockWithMatchingArg() { + MethodReference methodReference = createLocalMethodReference("methodName", ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(String.class, "stringArg"); + assertThat(methodReference.toInvokeCodeBlock(argCodeGenerator)) + .isEqualTo(CodeBlock.of("methodName(stringArg)")); + } + + @Test + void toInvokeCodeBlockWithMatchingArgs() { + MethodReference methodReference = createLocalMethodReference("methodName", + ClassName.get(Integer.class), ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(String.class, "stringArg") + .and(Integer.class, "integerArg"); + assertThat(methodReference.toInvokeCodeBlock(argCodeGenerator)) + .isEqualTo(CodeBlock.of("methodName(integerArg, stringArg)")); + } + + @Test + void toInvokeCodeBlockWithNonMatchingArg() { + MethodReference methodReference = createLocalMethodReference("methodName", + ClassName.get(Integer.class), ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(Integer.class, "integerArg"); + assertThatIllegalArgumentException().isThrownBy(() -> methodReference.toInvokeCodeBlock(argCodeGenerator)) + .withMessageContaining("parameter 1 of type java.lang.String is not supported"); + } + + @Test + void toInvokeCodeBlockWithStaticMethodAndMatchingDeclaringClass() { + MethodReference methodReference = createStaticMethodReference("methodName", TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithStaticMethodAndSeparateDeclaringClass() { + MethodReference methodReference = createStaticMethodReference("methodName", TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), INITIALIZER_CLASS_NAME)) + .isEqualTo(CodeBlock.of("com.example.Test.methodName()")); + } + + + private MethodReference createLocalMethodReference(String name, TypeName... argumentTypes) { + return createMethodReference(name, argumentTypes, null); + } + + private MethodReference createMethodReference(String name, TypeName[] argumentTypes, @Nullable ClassName declaringClass) { + MethodSpec method = createTestMethod(name, argumentTypes); + return new DefaultMethodReference(method, declaringClass); + } + + private MethodReference createStaticMethodReference(String name, ClassName declaringClass, TypeName... argumentTypes) { + MethodSpec method = createTestMethod(name, argumentTypes, Modifier.STATIC); + return new DefaultMethodReference(method, declaringClass); + } + + private MethodSpec createTestMethod(String name, TypeName[] argumentTypes, Modifier... modifiers) { + Builder method = MethodSpec.methodBuilder(name); + for (int i = 0; i < argumentTypes.length; i++) { + method.addParameter(argumentTypes[i], "args" + i); + } + method.addModifiers(modifiers); + return method.build(); + } + +} diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java index 34ac962746a..6e865bd4eb8 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java @@ -22,6 +22,7 @@ import javax.lang.model.element.Modifier; import org.junit.jupiter.api.Test; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; @@ -67,8 +68,8 @@ class GeneratedMethodTests { GeneratedMethod generatedMethod = create(emptyMethod); MethodReference methodReference = generatedMethod.toMethodReference(); assertThat(methodReference).isNotNull(); - assertThat(methodReference.toInvokeCodeBlock("test")) - .isEqualTo(CodeBlock.of("test.spring()")); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("spring()")); } @Test @@ -76,7 +77,8 @@ class GeneratedMethodTests { GeneratedMethod generatedMethod = create(method -> method.addModifiers(Modifier.STATIC)); MethodReference methodReference = generatedMethod.toMethodReference(); assertThat(methodReference).isNotNull(); - assertThat(methodReference.toInvokeCodeBlock()) + ClassName anotherDeclaringClass = ClassName.get("com.example", "Another"); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), anotherDeclaringClass)) .isEqualTo(CodeBlock.of("com.example.Test.spring()")); } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java b/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java deleted file mode 100644 index de5c79667b4..00000000000 --- a/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import org.junit.jupiter.api.Test; - -import org.springframework.javapoet.ClassName; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; - -/** - * Tests for {@link MethodReference}. - * - * @author Phillip Webb - */ -class MethodReferenceTests { - - private static final String EXPECTED_STATIC = "org.springframework.aot.generate.MethodReferenceTests::someMethod"; - - private static final String EXPECTED_ANONYMOUS_INSTANCE = "::someMethod"; - - private static final String EXPECTED_DECLARED_INSTANCE = "::someMethod"; - - - @Test - void ofWithStringWhenMethodNameIsNullThrowsException() { - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithStringCreatesMethodReference() { - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(methodName); - assertThat(reference).hasToString(EXPECTED_ANONYMOUS_INSTANCE); - } - - @Test - void ofWithClassAndStringWhenDeclaringClassIsNullThrowsException() { - Class declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofWithClassAndStringWhenMethodNameIsNullThrowsException() { - Class declaringClass = MethodReferenceTests.class; - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithClassAndStringCreatesMethodReference() { - Class declaringClass = MethodReferenceTests.class; - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); - } - - @Test - void ofWithClassNameAndStringWhenDeclaringClassIsNullThrowsException() { - ClassName declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofWithClassNameAndStringWhenMethodNameIsNullThrowsException() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithClassNameAndStringCreateMethodReference() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); - } - - @Test - void ofStaticWithClassAndStringWhenDeclaringClassIsNullThrowsException() { - Class declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofStaticWithClassAndStringWhenMethodNameIsEmptyThrowsException() { - Class declaringClass = MethodReferenceTests.class; - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofStaticWithClassAndStringCreatesMethodReference() { - Class declaringClass = MethodReferenceTests.class; - String methodName = "someMethod"; - MethodReference reference = MethodReference.ofStatic(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_STATIC); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameWhenDeclaringClassIsNullThrowsException() { - ClassName declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameWhenMethodNameIsEmptyThrowsException() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameCreatesMethodReference() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = "someMethod"; - MethodReference reference = MethodReference.ofStatic(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_STATIC); - } - - @Test - void toCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toCodeBlock(null)).hasToString("this::someMethod"); - } - - @Test - void toCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNotNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toCodeBlock("myInstance")) - .hasToString("myInstance::someMethod"); - } - - @Test - void toCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toCodeBlock(null)).hasToString(EXPECTED_STATIC); - } - - @Test - void toCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNotNullThrowsException() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThatIllegalArgumentException() - .isThrownBy(() -> reference.toCodeBlock("myInstance")).withMessage( - "'instanceVariable' must be null for static method references"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString("someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNullAndHasDecalredClass() { - MethodReference reference = MethodReference.of(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString( - "new org.springframework.aot.generate.MethodReferenceTests().someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNotNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toInvokeCodeBlock("myInstance")) - .hasToString("myInstance.someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString( - "org.springframework.aot.generate.MethodReferenceTests.someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNotNullThrowsException() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThatIllegalArgumentException() - .isThrownBy(() -> reference.toInvokeCodeBlock("myInstance")).withMessage( - "'instanceVariable' must be null for static method references"); - } - -}