All MethodReference to support a more flexible signature

Closes gh-29005
This commit is contained in:
Stephane Nicoll 2022-09-12 10:09:01 +02:00
commit fcb6baf2e9
27 changed files with 683 additions and 466 deletions

View File

@ -163,8 +163,8 @@ class ScopedProxyBeanRegistrationAotProcessor implements BeanRegistrationAotProc
method.addStatement("return ($T) factory.getObject()",
beanClass);
});
return CodeBlock.of("$T.of($T::$L)", InstanceSupplier.class,
beanRegistrationCode.getClassName(), generatedMethod.getName());
return CodeBlock.of("$T.of($L)", InstanceSupplier.class,
generatedMethod.toMethodReference().toCodeBlock());
}
}

View File

@ -26,6 +26,7 @@ import org.junit.jupiter.api.Test;
import org.springframework.aop.framework.AopInfrastructureBean;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.compile.Compiled;
import org.springframework.aot.test.generate.compile.TestCompiler;
@ -139,11 +140,14 @@ class ScopedProxyBeanRegistrationAotProcessorTests {
MethodReference methodReference = this.beanFactoryInitializationCode
.getInitializers().get(0);
this.beanFactoryInitializationCode.getTypeBuilder().set(type -> {
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"),
this.beanFactoryInitializationCode.getClassName());
type.addModifiers(Modifier.PUBLIC);
type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class));
type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC)
.addParameter(DefaultListableBeanFactory.class, "beanFactory")
.addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory")))
.addStatement(methodInvocation)
.build());
});
this.generationContext.writeGeneratedContent();

View File

@ -44,7 +44,6 @@ import org.springframework.aot.generate.AccessVisibility;
import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.BeanUtils;
@ -944,8 +943,7 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA
method.returns(this.target);
method.addCode(generateMethodCode(generationContext.getRuntimeHints()));
});
beanRegistrationCode.addInstancePostProcessor(
MethodReference.ofStatic(generatedClass.getName(), generateMethod.getName()));
beanRegistrationCode.addInstancePostProcessor(generateMethod.toMethodReference());
if (this.candidateResolver != null) {
registerHints(generationContext.getRuntimeHints());

View File

@ -107,16 +107,14 @@ class BeanDefinitionMethodGenerator {
GeneratedMethod generatedMethod = generateBeanDefinitionMethod(
generationContext, generatedClass.getName(), generatedMethods,
codeFragments, Modifier.PUBLIC);
return MethodReference.ofStatic(generatedClass.getName(),
generatedMethod.getName());
return generatedMethod.toMethodReference();
}
GeneratedMethods generatedMethods = beanRegistrationsCode.getMethods()
.withPrefix(getName());
GeneratedMethod generatedMethod = generateBeanDefinitionMethod(generationContext,
beanRegistrationsCode.getClassName(), generatedMethods, codeFragments,
Modifier.PRIVATE);
return MethodReference.ofStatic(beanRegistrationsCode.getClassName(),
generatedMethod.getName());
return generatedMethod.toMethodReference();
}
private BeanRegistrationCodeFragments getCodeFragments(GenerationContext generationContext,

View File

@ -24,6 +24,7 @@ import org.springframework.aot.generate.MethodReference;
* perform bean factory initialization.
*
* @author Phillip Webb
* @author Stephane Nicoll
* @since 6.0
* @see BeanFactoryInitializationAotContribution
*/
@ -41,10 +42,16 @@ public interface BeanFactoryInitializationCode {
GeneratedMethods getMethods();
/**
* Add an initializer method call.
* @param methodReference a reference to the initialize method to call. The
* referenced method must have the same functional signature as
* {@code Consumer<DefaultListableBeanFactory>}.
* Add an initializer method call. An initializer can use a flexible signature,
* using any of the following:
* <ul>
* <li>{@code DefaultListableBeanFactory}, or {@code ConfigurableListableBeanFactory}
* to use the bean factory.</li>
* <li>{@code ConfigurableEnvironment} or {@code Environment} to access the
* environment.</li>
* <li>{@code ResourceLoader} to load resources.</li>
* </ul>
* @param methodReference a reference to the initialize method to call.
*/
void addInitializer(MethodReference methodReference);

View File

@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
@ -65,8 +66,7 @@ class BeanRegistrationsAotContribution
BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator(generatedClass);
GeneratedMethod generatedMethod = codeGenerator.getMethods().add("registerBeanDefinitions", method ->
generateRegisterMethod(method, generationContext, codeGenerator));
beanFactoryInitializationCode.addInitializer(
MethodReference.of(generatedClass.getName(), generatedMethod.getName()));
beanFactoryInitializationCode.addInitializer(generatedMethod.toMethodReference());
}
private void generateRegisterMethod(MethodSpec.Builder method,
@ -82,9 +82,11 @@ class BeanRegistrationsAotContribution
MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator
.generateBeanDefinitionMethod(generationContext,
beanRegistrationsCode);
CodeBlock methodInvocation = beanDefinitionMethod.toInvokeCodeBlock(
ArgumentCodeGenerator.none(), beanRegistrationsCode.getClassName());
code.addStatement("$L.registerBeanDefinition($S, $L)",
BEAN_FACTORY_PARAMETER_NAME, beanName,
beanDefinitionMethod.toInvokeCodeBlock());
methodInvocation);
});
method.addCode(code.build());
}

View File

@ -24,6 +24,7 @@ import java.util.function.Predicate;
import org.springframework.aot.generate.AccessVisibility;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionHolder;
@ -156,7 +157,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments
MethodReference generatedMethod = methodGenerator
.generateBeanDefinitionMethod(generationContext,
this.beanRegistrationsCode);
return generatedMethod.toInvokeCodeBlock();
return generatedMethod.toInvokeCodeBlock(ArgumentCodeGenerator.none());
}
return null;
}

View File

@ -28,6 +28,7 @@ import org.springframework.aot.generate.AccessVisibility;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.beans.factory.support.InstanceSupplier;
import org.springframework.beans.factory.support.RegisteredBean;
@ -296,8 +297,9 @@ class InstanceSupplierCodeGenerator {
REGISTERED_BEAN_PARAMETER_NAME, declaringClass, factoryMethodName, args);
}
private CodeBlock generateReturnStatement(GeneratedMethod getInstanceMethod) {
return CodeBlock.of("$T.$L()", this.className, getInstanceMethod.getName());
private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) {
return generatedMethod.toMethodReference().toInvokeCodeBlock(
ArgumentCodeGenerator.none(), this.className);
}
private CodeBlock generateWithGeneratorCode(boolean hasArguments, CodeBlock newInstance) {

View File

@ -24,6 +24,7 @@ import javax.lang.model.element.Modifier;
import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess;
@ -161,13 +162,16 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests {
Class<?> target = registeredBean.getBeanClass();
MethodReference methodReference = this.beanRegistrationCode.getInstancePostProcessors().get(0);
this.beanRegistrationCode.getTypeBuilder().set(type -> {
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
ArgumentCodeGenerator.of(RegisteredBean.class, "registeredBean").and(target, "instance"),
this.beanRegistrationCode.getClassName());
type.addModifiers(Modifier.PUBLIC);
type.addSuperinterface(ParameterizedTypeName.get(BiFunction.class, RegisteredBean.class, target, target));
type.addMethod(MethodSpec.methodBuilder("apply")
.addModifiers(Modifier.PUBLIC)
.addParameter(RegisteredBean.class, "registeredBean")
.addParameter(target, "instance").returns(target)
.addStatement("return $L", methodReference.toInvokeCodeBlock(CodeBlock.of("registeredBean"), CodeBlock.of("instance")))
.addStatement("return $L", methodInvocation)
.build());
});

View File

@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess;
import org.springframework.aot.test.generate.compile.Compiled;
@ -129,8 +130,7 @@ class BeanDefinitionMethodGeneratorTests {
.addParameter(RegisteredBean.class, "registeredBean")
.addParameter(TestBean.class, "testBean")
.returns(TestBean.class).addCode("return new $T($S);", TestBean.class, "postprocessed"));
beanRegistrationCode.addInstancePostProcessor(MethodReference.ofStatic(
beanRegistrationCode.getClassName(), generatedMethod.getName()));
beanRegistrationCode.addInstancePostProcessor(generatedMethod.toMethodReference());
};
List<BeanRegistrationAotContribution> aotContributions = Collections
.singletonList(aotContribution);
@ -167,8 +167,7 @@ class BeanDefinitionMethodGeneratorTests {
.addParameter(RegisteredBean.class, "registeredBean")
.addParameter(TestBean.class, "testBean")
.returns(TestBean.class).addCode("return new $T($S);", TestBean.class, "postprocessed"));
beanRegistrationCode.addInstancePostProcessor(MethodReference.ofStatic(
beanRegistrationCode.getClassName(), generatedMethod.getName()));
beanRegistrationCode.addInstancePostProcessor(generatedMethod.toMethodReference());
};
List<BeanRegistrationAotContribution> aotContributions = Collections
.singletonList(aotContribution);
@ -416,12 +415,14 @@ class BeanDefinitionMethodGeneratorTests {
private void compile(MethodReference method,
BiConsumer<RootBeanDefinition, Compiled> result) {
this.beanRegistrationsCode.getTypeBuilder().set(type -> {
CodeBlock methodInvocation = method.toInvokeCodeBlock(ArgumentCodeGenerator.none(),
this.beanRegistrationsCode.getClassName());
type.addModifiers(Modifier.PUBLIC);
type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, BeanDefinition.class));
type.addMethod(MethodSpec.methodBuilder("get")
.addModifiers(Modifier.PUBLIC)
.returns(BeanDefinition.class)
.addCode("return $L;", method.toInvokeCodeBlock()).build());
.addCode("return $L;", methodInvocation).build());
});
this.generationContext.writeGeneratedContent();
TestCompiler.forSystem().withFiles(this.generationContext.getGeneratedFiles()).compile(compiled ->

View File

@ -31,6 +31,7 @@ import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.ClassNameGenerator;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.TestTarget;
import org.springframework.aot.test.generate.compile.Compiled;
@ -155,11 +156,14 @@ class BeanRegistrationsAotContributionTests {
MethodReference methodReference = this.beanFactoryInitializationCode
.getInitializers().get(0);
this.beanFactoryInitializationCode.getTypeBuilder().set(type -> {
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"),
this.beanFactoryInitializationCode.getClassName());
type.addModifiers(Modifier.PUBLIC);
type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class));
type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC)
.addParameter(DefaultListableBeanFactory.class, "beanFactory")
.addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory")))
.addStatement(methodInvocation)
.build());
});
this.generationContext.writeGeneratedContent();

View File

@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.javapoet.ClassName;
/**
* Mock {@link BeanFactoryInitializationCode} implementation.
@ -46,6 +47,9 @@ public class MockBeanFactoryInitializationCode implements BeanFactoryInitializat
.addForFeature("TestCode", this.typeBuilder);
}
public ClassName getClassName() {
return this.generatedClass.getName();
}
public DeferredTypeBuilder getTypeBuilder() {
return this.typeBuilder;

View File

@ -33,7 +33,6 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.aop.framework.autoproxy.AutoProxyUtils;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.hint.ResourceHints;
import org.springframework.aot.hint.TypeReference;
import org.springframework.beans.PropertyValues;
@ -536,7 +535,7 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
.add("addImportAwareBeanPostProcessors", method ->
generateAddPostProcessorMethod(method, mappings));
beanFactoryInitializationCode
.addInitializer(MethodReference.of(generatedMethod.getName()));
.addInitializer(generatedMethod.toMethodReference());
ResourceHints hints = generationContext.getRuntimeHints().resources();
mappings.forEach(
(target, from) -> hints.registerType(TypeReference.of(from)));

View File

@ -18,6 +18,7 @@ package org.springframework.context.aot;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import javax.lang.model.element.Modifier;
@ -25,16 +26,24 @@ import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ResourceLoader;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.ParameterizedTypeName;
import org.springframework.javapoet.TypeName;
import org.springframework.javapoet.TypeSpec;
import org.springframework.lang.Nullable;
/**
* Internal code generator to create the {@link ApplicationContextInitializer}.
@ -88,12 +97,17 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia
BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class);
code.addStatement("$L.setDependencyComparator($T.INSTANCE)",
BEAN_FACTORY_VARIABLE, AnnotationAwareOrderComparator.class);
ArgumentCodeGenerator argCodeGenerator = createInitializerMethodArgumentCodeGenerator();
for (MethodReference initializer : this.initializers) {
code.addStatement(initializer.toInvokeCodeBlock(CodeBlock.of(BEAN_FACTORY_VARIABLE)));
code.addStatement(initializer.toInvokeCodeBlock(argCodeGenerator, this.generatedClass.getName()));
}
return code.build();
}
static ArgumentCodeGenerator createInitializerMethodArgumentCodeGenerator() {
return ArgumentCodeGenerator.from(new InitializerMethodArgumentCodeGenerator());
}
GeneratedClass getGeneratedClass() {
return this.generatedClass;
}
@ -108,4 +122,30 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia
this.initializers.add(methodReference);
}
private static class InitializerMethodArgumentCodeGenerator implements Function<TypeName, CodeBlock> {
@Override
@Nullable
public CodeBlock apply(TypeName typeName) {
return (typeName instanceof ClassName className ? apply(className) : null);
}
@Nullable
private CodeBlock apply(ClassName className) {
String name = className.canonicalName();
if (name.equals(DefaultListableBeanFactory.class.getName())
|| name.equals(ConfigurableListableBeanFactory.class.getName())) {
return CodeBlock.of(BEAN_FACTORY_VARIABLE);
}
else if (name.equals(ConfigurableEnvironment.class.getName())
|| name.equals(Environment.class.getName())) {
return CodeBlock.of("$L.getConfigurableEnvironment()", APPLICATION_CONTEXT_VARIABLE);
}
else if (name.equals(ResourceLoader.class.getName())) {
return CodeBlock.of(APPLICATION_CONTEXT_VARIABLE);
}
return null;
}
}
}

View File

@ -27,6 +27,7 @@ import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.hint.ResourcePatternHint;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.compile.Compiled;
@ -162,11 +163,14 @@ class ConfigurationClassPostProcessorAotContributionTests {
private void compile(BiConsumer<Consumer<DefaultListableBeanFactory>, Compiled> result) {
MethodReference methodReference = this.beanFactoryInitializationCode.getInitializers().get(0);
this.beanFactoryInitializationCode.getTypeBuilder().set(type -> {
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"),
this.beanFactoryInitializationCode.getClassName());
type.addModifiers(Modifier.PUBLIC);
type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class));
type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC)
.addParameter(DefaultListableBeanFactory.class, "beanFactory")
.addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory")))
.addStatement(methodInvocation)
.build());
});
this.generationContext.writeGeneratedContent();

View File

@ -0,0 +1,77 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.context.aot;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.AbstractBeanFactory;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.env.StandardEnvironment;
import org.springframework.core.io.ResourceLoader;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link ApplicationContextInitializationCodeGenerator}.
*
* @author Stephane Nicoll
*/
class ApplicationContextInitializationCodeGeneratorTests {
private static final ArgumentCodeGenerator argCodeGenerator = ApplicationContextInitializationCodeGenerator.
createInitializerMethodArgumentCodeGenerator();
@ParameterizedTest
@MethodSource("methodArguments")
void argumentsForSupportedTypesAreResolved(Class<?> target, String expectedArgument) {
CodeBlock code = CodeBlock.of(expectedArgument);
assertThat(argCodeGenerator.generateCode(ClassName.get(target))).isEqualTo(code);
}
@Test
void argumentForUnsupportedBeanFactoryIsNotResolved() {
assertThat(argCodeGenerator.generateCode(ClassName.get(AbstractBeanFactory.class))).isNull();
}
@Test
void argumentForUnsupportedEnvironmentIsNotResolved() {
assertThat(argCodeGenerator.generateCode(ClassName.get(StandardEnvironment.class))).isNull();
}
static Stream<Arguments> methodArguments() {
String applicationContext = "applicationContext";
String environment = applicationContext + ".getConfigurableEnvironment()";
return Stream.of(
Arguments.of(DefaultListableBeanFactory.class, "beanFactory"),
Arguments.of(ConfigurableListableBeanFactory.class, "beanFactory"),
Arguments.of(ConfigurableEnvironment.class, environment),
Arguments.of(Environment.class, environment),
Arguments.of(ResourceLoader.class, applicationContext));
}
}

View File

@ -0,0 +1,134 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.aot.generate;
import java.util.ArrayList;
import java.util.List;
import javax.lang.model.element.Modifier;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.TypeName;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* Default {@link MethodReference} implementation based on a {@link MethodSpec}.
*
* @author Stephane Nicoll
* @author Phillip Webb
* @since 6.0
*/
public class DefaultMethodReference implements MethodReference {
private final MethodSpec method;
@Nullable
private final ClassName declaringClass;
public DefaultMethodReference(MethodSpec method, @Nullable ClassName declaringClass) {
this.method = method;
this.declaringClass = declaringClass;
}
@Override
public CodeBlock toCodeBlock() {
String methodName = this.method.name;
if (isStatic()) {
Assert.notNull(this.declaringClass, "static method reference must define a declaring class");
return CodeBlock.of("$T::$L", this.declaringClass, methodName);
}
else {
return CodeBlock.of("this::$L", methodName);
}
}
public CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator,
@Nullable ClassName targetClassName) {
String methodName = this.method.name;
CodeBlock.Builder code = CodeBlock.builder();
if (isStatic()) {
Assert.notNull(this.declaringClass, "static method reference must define a declaring class");
if (isSameDeclaringClass(targetClassName)) {
code.add("$L", methodName);
}
else {
code.add("$T.$L", this.declaringClass, methodName);
}
}
else {
if (!isSameDeclaringClass(targetClassName)) {
code.add(instantiateDeclaringClass(this.declaringClass));
}
code.add("$L", methodName);
}
code.add("(");
addArguments(code, argumentCodeGenerator);
code.add(")");
return code.build();
}
/**
* Add the code for the method arguments using the specified
* {@link ArgumentCodeGenerator} if necessary.
* @param code the code builder to use to add method arguments
* @param argumentCodeGenerator the code generator to use
*/
protected void addArguments(CodeBlock.Builder code, ArgumentCodeGenerator argumentCodeGenerator) {
List<CodeBlock> arguments = new ArrayList<>();
TypeName[] argumentTypes = this.method.parameters.stream()
.map(parameter -> parameter.type).toArray(TypeName[]::new);
for (int i = 0; i < argumentTypes.length; i++) {
TypeName argumentType = argumentTypes[i];
CodeBlock argumentCode = argumentCodeGenerator.generateCode(argumentType);
if (argumentCode == null) {
throw new IllegalArgumentException("Could not generate code for " + this
+ ": parameter " + i + " of type " + argumentType + " is not supported");
}
arguments.add(argumentCode);
}
code.add(CodeBlock.join(arguments, ", "));
}
protected CodeBlock instantiateDeclaringClass(ClassName declaringClass) {
return CodeBlock.of("new $T().", declaringClass);
}
private boolean isStatic() {
return this.method.modifiers.contains(Modifier.STATIC);
}
private boolean isSameDeclaringClass(ClassName declaringClass) {
return this.declaringClass == null || this.declaringClass.equals(declaringClass);
}
@Override
public String toString() {
String methodName = this.method.name;
if (isStatic()) {
return this.declaringClass + "::" + methodName;
}
else {
return ((this.declaringClass != null)
? "<" + this.declaringClass + ">" : "<instance>")
+ "::" + methodName;
}
}
}

View File

@ -55,7 +55,7 @@ public final class GeneratedClass {
GeneratedClass(ClassName name, Consumer<TypeSpec.Builder> type) {
this.name = name;
this.type = type;
this.methods = new GeneratedMethods(this::generateSequencedMethodName);
this.methods = new GeneratedMethods(name, this::generateSequencedMethodName);
}

View File

@ -18,6 +18,7 @@ package org.springframework.aot.generate;
import java.util.function.Consumer;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.MethodSpec;
import org.springframework.util.Assert;
@ -25,11 +26,14 @@ import org.springframework.util.Assert;
* A generated method.
*
* @author Phillip Webb
* @author Stephane Nicoll
* @since 6.0
* @see GeneratedMethods
*/
public final class GeneratedMethod {
private final ClassName className;
private final String name;
private final MethodSpec methodSpec;
@ -39,12 +43,14 @@ public final class GeneratedMethod {
* Create a new {@link GeneratedMethod} instance with the given name. This
* constructor is package-private since names should only be generated via
* {@link GeneratedMethods}.
* @param className the declaring class of the method
* @param name the generated method name
* @param method consumer to generate the method
*/
GeneratedMethod(String name, Consumer<MethodSpec.Builder> method) {
GeneratedMethod(ClassName className, String name, Consumer<MethodSpec.Builder> method) {
this.className = className;
this.name = name;
MethodSpec.Builder builder = MethodSpec.methodBuilder(getName());
MethodSpec.Builder builder = MethodSpec.methodBuilder(this.name);
method.accept(builder);
this.methodSpec = builder.build();
Assert.state(this.name.equals(this.methodSpec.name),
@ -60,6 +66,14 @@ public final class GeneratedMethod {
return this.name;
}
/**
* Return a {@link MethodReference} to this generated method.
* @return a method reference
*/
public MethodReference toMethodReference() {
return new DefaultMethodReference(this.methodSpec, this.className);
}
/**
* Return the {@link MethodSpec} for this generated method.
* @return the method spec

View File

@ -22,6 +22,7 @@ import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.MethodSpec.Builder;
import org.springframework.util.Assert;
@ -30,11 +31,14 @@ import org.springframework.util.Assert;
* A managed collection of generated methods.
*
* @author Phillip Webb
* @author Stephane Nicoll
* @since 6.0
* @see GeneratedMethod
*/
public class GeneratedMethods {
private final ClassName className;
private final Function<MethodName, String> methodNameGenerator;
private final MethodName prefix;
@ -44,18 +48,22 @@ public class GeneratedMethods {
/**
* Create a new {@link GeneratedMethods} using the specified method name
* generator.
* @param className the declaring class name
* @param methodNameGenerator the method name generator
*/
GeneratedMethods(Function<MethodName, String> methodNameGenerator) {
GeneratedMethods(ClassName className, Function<MethodName, String> methodNameGenerator) {
Assert.notNull(className, "'className' must not be null");
Assert.notNull(methodNameGenerator, "'methodNameGenerator' must not be null");
this.className = className;
this.methodNameGenerator = methodNameGenerator;
this.prefix = MethodName.NONE;
this.generatedMethods = new ArrayList<>();
}
private GeneratedMethods(Function<MethodName, String> methodNameGenerator,
private GeneratedMethods(ClassName className, Function<MethodName, String> methodNameGenerator,
MethodName prefix, List<GeneratedMethod> generatedMethods) {
this.className = className;
this.methodNameGenerator = methodNameGenerator;
this.prefix = prefix;
this.generatedMethods = generatedMethods;
@ -82,7 +90,7 @@ public class GeneratedMethods {
Assert.notNull(suggestedNameParts, "'suggestedNameParts' must not be null");
Assert.notNull(method, "'method' must not be null");
String generatedName = this.methodNameGenerator.apply(this.prefix.and(suggestedNameParts));
GeneratedMethod generatedMethod = new GeneratedMethod(generatedName, method);
GeneratedMethod generatedMethod = new GeneratedMethod(this.className, generatedName, method);
this.generatedMethods.add(generatedMethod);
return generatedMethod;
}
@ -90,7 +98,8 @@ public class GeneratedMethods {
public GeneratedMethods withPrefix(String prefix) {
Assert.notNull(prefix, "'prefix' must not be null");
return new GeneratedMethods(this.methodNameGenerator, this.prefix.and(prefix), this.generatedMethods);
return new GeneratedMethods(this.className, this.methodNameGenerator,
this.prefix.and(prefix), this.generatedMethods);
}
/**

View File

@ -16,223 +16,124 @@
package org.springframework.aot.generate;
import java.util.function.Function;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.TypeName;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* A reference to a static or instance method.
* A reference to a method with convenient code generation for
* referencing, or invoking it.
*
* @author Stephane Nicoll
* @author Phillip Webb
* @since 6.0
*/
public final class MethodReference {
private final Kind kind;
@Nullable
private final ClassName declaringClass;
private final String methodName;
private MethodReference(Kind kind, @Nullable ClassName declaringClass,
String methodName) {
this.kind = kind;
this.declaringClass = declaringClass;
this.methodName = methodName;
}
/**
* Create a new method reference that refers to the given instance method.
* @param methodName the method name
* @return a new {@link MethodReference} instance
*/
public static MethodReference of(String methodName) {
Assert.hasLength(methodName, "'methodName' must not be empty");
return new MethodReference(Kind.INSTANCE, null, methodName);
}
/**
* Create a new method reference that refers to the given instance method.
* @param declaringClass the declaring class
* @param methodName the method name
* @return a new {@link MethodReference} instance
*/
public static MethodReference of(Class<?> declaringClass, String methodName) {
Assert.notNull(declaringClass, "'declaringClass' must not be null");
Assert.hasLength(methodName, "'methodName' must not be empty");
return new MethodReference(Kind.INSTANCE, ClassName.get(declaringClass),
methodName);
}
/**
* Create a new method reference that refers to the given instance method.
* @param declaringClass the declaring class
* @param methodName the method name
* @return a new {@link MethodReference} instance
*/
public static MethodReference of(ClassName declaringClass, String methodName) {
Assert.notNull(declaringClass, "'declaringClass' must not be null");
Assert.hasLength(methodName, "'methodName' must not be empty");
return new MethodReference(Kind.INSTANCE, declaringClass, methodName);
}
/**
* Create a new method reference that refers to the given static method.
* @param declaringClass the declaring class
* @param methodName the method name
* @return a new {@link MethodReference} instance
*/
public static MethodReference ofStatic(Class<?> declaringClass, String methodName) {
Assert.notNull(declaringClass, "'declaringClass' must not be null");
Assert.hasLength(methodName, "'methodName' must not be empty");
return new MethodReference(Kind.STATIC, ClassName.get(declaringClass),
methodName);
}
/**
* Create a new method reference that refers to the given static method.
* @param declaringClass the declaring class
* @param methodName the method name
* @return a new {@link MethodReference} instance
*/
public static MethodReference ofStatic(ClassName declaringClass, String methodName) {
Assert.notNull(declaringClass, "'declaringClass' must not be null");
Assert.hasLength(methodName, "'methodName' must not be empty");
return new MethodReference(Kind.STATIC, declaringClass, methodName);
}
/**
* Return the referenced declaring class.
* @return the declaring class
*/
@Nullable
public ClassName getDeclaringClass() {
return this.declaringClass;
}
/**
* Return the referenced method name.
* @return the method name
*/
public String getMethodName() {
return this.methodName;
}
public interface MethodReference {
/**
* Return this method reference as a {@link CodeBlock}. If the reference is
* to an instance method then {@code this::<method name>} will be returned.
* @return a code block for the method reference.
* @see #toCodeBlock(String)
*/
public CodeBlock toCodeBlock() {
return toCodeBlock(null);
CodeBlock toCodeBlock();
/**
* Return this method reference as a {@link CodeBlock} using the specified
* {@link ArgumentCodeGenerator}.
* @param argumentCodeGenerator the argument code generator to use
* @return a code block to invoke the method
*/
default CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator) {
return toInvokeCodeBlock(argumentCodeGenerator, null);
}
/**
* Return this method reference as a {@link CodeBlock}. If the reference is
* to an instance method and {@code instanceVariable} is {@code null} then
* {@code this::<method name>} will be returned. No {@code instanceVariable}
* can be specified for static method references.
* @param instanceVariable the instance variable or {@code null}
* @return a code block for the method reference.
* @see #toCodeBlock(String)
* Return this method reference as a {@link CodeBlock} using the specified
* {@link ArgumentCodeGenerator}. The {@code targetClassName} defines the
* context in which the method invocation is added.
* <p>If the caller has an instance of the type in which this method is
* defined, it can hint that by specifying the type as a target class.
* @param argumentCodeGenerator the argument code generator to use
* @param targetClassName the target class name
* @return a code block to invoke the method
*/
public CodeBlock toCodeBlock(@Nullable String instanceVariable) {
return switch (this.kind) {
case INSTANCE -> toCodeBlockForInstance(instanceVariable);
case STATIC -> toCodeBlockForStatic(instanceVariable);
};
}
CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator, @Nullable ClassName targetClassName);
private CodeBlock toCodeBlockForInstance(@Nullable String instanceVariable) {
instanceVariable = (instanceVariable != null) ? instanceVariable : "this";
return CodeBlock.of("$L::$L", instanceVariable, this.methodName);
}
private CodeBlock toCodeBlockForStatic(@Nullable String instanceVariable) {
Assert.isTrue(instanceVariable == null,
"'instanceVariable' must be null for static method references");
return CodeBlock.of("$T::$L", this.declaringClass, this.methodName);
}
/**
* Return this method reference as an invocation {@link CodeBlock}.
* @param arguments the method arguments
* @return a code back to invoke the method
* Strategy for generating code for arguments based on their type.
*/
public CodeBlock toInvokeCodeBlock(CodeBlock... arguments) {
return toInvokeCodeBlock(null, arguments);
}
interface ArgumentCodeGenerator {
/**
* Return this method reference as an invocation {@link CodeBlock}.
* @param instanceVariable the instance variable or {@code null}
* @param arguments the method arguments
* @return a code back to invoke the method
*/
public CodeBlock toInvokeCodeBlock(@Nullable String instanceVariable,
CodeBlock... arguments) {
/**
* Generate the code for the given argument type. If this type is
* not supported, return {@code null}.
* @param argumentType the argument type
* @return the code for this argument, or {@code null}
*/
@Nullable
CodeBlock generateCode(TypeName argumentType);
return switch (this.kind) {
case INSTANCE -> toInvokeCodeBlockForInstance(instanceVariable, arguments);
case STATIC -> toInvokeCodeBlockForStatic(instanceVariable, arguments);
};
}
private CodeBlock toInvokeCodeBlockForInstance(@Nullable String instanceVariable,
CodeBlock[] arguments) {
CodeBlock.Builder code = CodeBlock.builder();
if (instanceVariable != null) {
code.add("$L.", instanceVariable);
/**
* Factory method that returns an {@link ArgumentCodeGenerator} that
* always returns {@code null}.
* @return a new {@link ArgumentCodeGenerator} instance
*/
static ArgumentCodeGenerator none() {
return from(type -> null);
}
else if (this.declaringClass != null) {
code.add("new $T().", this.declaringClass);
/**
* Factory method that can be used to create an {@link ArgumentCodeGenerator}
* that support only the given argument type.
* @param argumentType the argument type
* @param argumentCode the code for an argument of that type
* @return a new {@link ArgumentCodeGenerator} instance
*/
static ArgumentCodeGenerator of(Class<?> argumentType, String argumentCode) {
return from(candidateType -> (candidateType.equals(ClassName.get(argumentType))
? CodeBlock.of(argumentCode) : null));
}
code.add("$L", this.methodName);
addArguments(code, arguments);
return code.build();
}
private CodeBlock toInvokeCodeBlockForStatic(@Nullable String instanceVariable,
CodeBlock[] arguments) {
Assert.isTrue(instanceVariable == null,
"'instanceVariable' must be null for static method references");
CodeBlock.Builder code = CodeBlock.builder();
code.add("$T.$L", this.declaringClass, this.methodName);
addArguments(code, arguments);
return code.build();
}
private void addArguments(CodeBlock.Builder code, CodeBlock[] arguments) {
code.add("(");
for (int i = 0; i < arguments.length; i++) {
if (i != 0) {
code.add(", ");
}
code.add(arguments[i]);
/**
* Factory method that creates a new {@link ArgumentCodeGenerator} from
* a lambda friendly function. The given function is provided with the
* argument type and must provide the code to use or {@code null} if
* the type is not supported.
* @param function the resolver function
* @return a new {@link ArgumentCodeGenerator} instance backed by the function
*/
static ArgumentCodeGenerator from(Function<TypeName, CodeBlock> function) {
return function::apply;
}
code.add(")");
}
@Override
public String toString() {
return switch (this.kind) {
case INSTANCE -> ((this.declaringClass != null) ? "<" + this.declaringClass + ">"
: "<instance>") + "::" + this.methodName;
case STATIC -> this.declaringClass + "::" + this.methodName;
};
}
/**
* Create a new composed {@link ArgumentCodeGenerator} by combining this
* generator with supporting the given argument type.
* @param argumentType the argument type
* @param argumentCode the code for an argument of that type
* @return a new composite {@link ArgumentCodeGenerator} instance
*/
default ArgumentCodeGenerator and(Class<?> argumentType, String argumentCode) {
return and(ArgumentCodeGenerator.of(argumentType, argumentCode));
}
/**
* Create a new composed {@link ArgumentCodeGenerator} by combining this
* generator with the given generator.
* @param argumentCodeGenerator the argument generator to add
* @return a new composite {@link ArgumentCodeGenerator} instance
*/
default ArgumentCodeGenerator and(ArgumentCodeGenerator argumentCodeGenerator) {
return from(type -> {
CodeBlock code = generateCode(type);
return (code != null ? code : argumentCodeGenerator.generateCode(type));
});
}
private enum Kind {
INSTANCE, STATIC
}
}

View File

@ -0,0 +1,199 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.aot.generate;
import javax.lang.model.element.Modifier;
import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.MethodSpec.Builder;
import org.springframework.javapoet.TypeName;
import org.springframework.lang.Nullable;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
/**
* Tests for {@link DefaultMethodReference}.
*
* @author Phillip Webb
* @author Stephane Nicoll
*/
class DefaultMethodReferenceTests {
private static final String EXPECTED_STATIC = "org.springframework.aot.generate.DefaultMethodReferenceTests::someMethod";
private static final String EXPECTED_ANONYMOUS_INSTANCE = "<instance>::someMethod";
private static final String EXPECTED_DECLARED_INSTANCE = "<org.springframework.aot.generate.DefaultMethodReferenceTests>::someMethod";
private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test");
private static final ClassName INITIALIZER_CLASS_NAME = ClassName.get("com.example", "Initializer");
@Test
void createWithStringCreatesMethodReference() {
MethodSpec method = createTestMethod("someMethod", new TypeName[0]);
MethodReference reference = new DefaultMethodReference(method, null);
assertThat(reference).hasToString(EXPECTED_ANONYMOUS_INSTANCE);
}
@Test
void createWithClassNameAndStringCreateMethodReference() {
ClassName declaringClass = ClassName.get(DefaultMethodReferenceTests.class);
MethodReference reference = createMethodReference("someMethod", new TypeName[0], declaringClass);
assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE);
}
@Test
void createWithStaticAndClassAndStringCreatesMethodReference() {
ClassName declaringClass = ClassName.get(DefaultMethodReferenceTests.class);
MethodReference reference = createStaticMethodReference("someMethod", declaringClass);
assertThat(reference).hasToString(EXPECTED_STATIC);
}
@Test
void toCodeBlock() {
assertThat(createLocalMethodReference("methodName").toCodeBlock())
.isEqualTo(CodeBlock.of("this::methodName"));
}
@Test
void toCodeBlockWithStaticMethod() {
assertThat(createStaticMethodReference("methodName", TEST_CLASS_NAME).toCodeBlock())
.isEqualTo(CodeBlock.of("com.example.Test::methodName"));
}
@Test
void toCodeBlockWithStaticMethodRequiresDeclaringClass() {
MethodSpec method = createTestMethod("methodName", new TypeName[0], Modifier.STATIC);
MethodReference methodReference = new DefaultMethodReference(method, null);
assertThatIllegalArgumentException().isThrownBy(methodReference::toCodeBlock)
.withMessage("static method reference must define a declaring class");
}
@Test
void toInvokeCodeBlockWithNullDeclaringClassAndTargetClass() {
MethodSpec method = createTestMethod("methodName", new TypeName[0]);
MethodReference methodReference = new DefaultMethodReference(method, null);
assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME))
.isEqualTo(CodeBlock.of("methodName()"));
}
@Test
void toInvokeCodeBlockWithNullDeclaringClassAndNullTargetClass() {
MethodSpec method = createTestMethod("methodName", new TypeName[0]);
MethodReference methodReference = new DefaultMethodReference(method, null);
assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none()))
.isEqualTo(CodeBlock.of("methodName()"));
}
@Test
void toInvokeCodeBlockWithDeclaringClassAndNullTargetClass() {
MethodSpec method = createTestMethod("methodName", new TypeName[0]);
MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME);
assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none()))
.isEqualTo(CodeBlock.of("new com.example.Test().methodName()"));
}
@Test
void toInvokeCodeBlockWithMatchingTargetClass() {
MethodSpec method = createTestMethod("methodName", new TypeName[0]);
MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME);
CodeBlock invocation = methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME);
// Assume com.example.Test is in a `test` variable.
assertThat(CodeBlock.of("$L.$L", "test", invocation)).isEqualTo(CodeBlock.of("test.methodName()"));
}
@Test
void toInvokeCodeBlockWithNonMatchingDeclaringClass() {
MethodSpec method = createTestMethod("methodName", new TypeName[0]);
MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME);
assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), INITIALIZER_CLASS_NAME))
.isEqualTo(CodeBlock.of("new com.example.Test().methodName()"));
}
@Test
void toInvokeCodeBlockWithMatchingArg() {
MethodReference methodReference = createLocalMethodReference("methodName", ClassName.get(String.class));
ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(String.class, "stringArg");
assertThat(methodReference.toInvokeCodeBlock(argCodeGenerator))
.isEqualTo(CodeBlock.of("methodName(stringArg)"));
}
@Test
void toInvokeCodeBlockWithMatchingArgs() {
MethodReference methodReference = createLocalMethodReference("methodName",
ClassName.get(Integer.class), ClassName.get(String.class));
ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(String.class, "stringArg")
.and(Integer.class, "integerArg");
assertThat(methodReference.toInvokeCodeBlock(argCodeGenerator))
.isEqualTo(CodeBlock.of("methodName(integerArg, stringArg)"));
}
@Test
void toInvokeCodeBlockWithNonMatchingArg() {
MethodReference methodReference = createLocalMethodReference("methodName",
ClassName.get(Integer.class), ClassName.get(String.class));
ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(Integer.class, "integerArg");
assertThatIllegalArgumentException().isThrownBy(() -> methodReference.toInvokeCodeBlock(argCodeGenerator))
.withMessageContaining("parameter 1 of type java.lang.String is not supported");
}
@Test
void toInvokeCodeBlockWithStaticMethodAndMatchingDeclaringClass() {
MethodReference methodReference = createStaticMethodReference("methodName", TEST_CLASS_NAME);
assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME))
.isEqualTo(CodeBlock.of("methodName()"));
}
@Test
void toInvokeCodeBlockWithStaticMethodAndSeparateDeclaringClass() {
MethodReference methodReference = createStaticMethodReference("methodName", TEST_CLASS_NAME);
assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), INITIALIZER_CLASS_NAME))
.isEqualTo(CodeBlock.of("com.example.Test.methodName()"));
}
private MethodReference createLocalMethodReference(String name, TypeName... argumentTypes) {
return createMethodReference(name, argumentTypes, null);
}
private MethodReference createMethodReference(String name, TypeName[] argumentTypes, @Nullable ClassName declaringClass) {
MethodSpec method = createTestMethod(name, argumentTypes);
return new DefaultMethodReference(method, declaringClass);
}
private MethodReference createStaticMethodReference(String name, ClassName declaringClass, TypeName... argumentTypes) {
MethodSpec method = createTestMethod(name, argumentTypes, Modifier.STATIC);
return new DefaultMethodReference(method, declaringClass);
}
private MethodSpec createTestMethod(String name, TypeName[] argumentTypes, Modifier... modifiers) {
Builder method = MethodSpec.methodBuilder(name);
for (int i = 0; i < argumentTypes.length; i++) {
method.addParameter(argumentTypes[i], "args" + i);
}
method.addModifiers(modifiers);
return method.build();
}
}

View File

@ -18,8 +18,13 @@ package org.springframework.aot.generate;
import java.util.function.Consumer;
import javax.lang.model.element.Modifier;
import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import static org.assertj.core.api.Assertions.assertThat;
@ -29,30 +34,56 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
* Tests for {@link GeneratedMethod}.
*
* @author Phillip Webb
* @author Stephane Nicoll
*/
class GeneratedMethodTests {
private static final Consumer<MethodSpec.Builder> methodSpecCustomizer = method -> {};
private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test");
private static final Consumer<MethodSpec.Builder> emptyMethod = method -> {};
private static final String NAME = "spring";
@Test
void getNameReturnsName() {
GeneratedMethod generatedMethod = new GeneratedMethod(NAME, methodSpecCustomizer);
GeneratedMethod generatedMethod = new GeneratedMethod(TEST_CLASS_NAME, NAME, emptyMethod);
assertThat(generatedMethod.getName()).isSameAs(NAME);
}
@Test
void generateMethodSpecReturnsMethodSpec() {
GeneratedMethod generatedMethod = new GeneratedMethod(NAME, method -> method.addJavadoc("Test"));
GeneratedMethod generatedMethod = create(method -> method.addJavadoc("Test"));
assertThat(generatedMethod.getMethodSpec().javadoc).asString().contains("Test");
}
@Test
void generateMethodSpecWhenMethodNameIsChangedThrowsException() {
assertThatIllegalStateException().isThrownBy(() ->
new GeneratedMethod(NAME, method -> method.setName("badname")).getMethodSpec())
.withMessage("'method' consumer must not change the generated method name");
create(method -> method.setName("badname")).getMethodSpec())
.withMessage("'method' consumer must not change the generated method name");
}
@Test
void toMethodReferenceWithInstanceMethod() {
GeneratedMethod generatedMethod = create(emptyMethod);
MethodReference methodReference = generatedMethod.toMethodReference();
assertThat(methodReference).isNotNull();
assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME))
.isEqualTo(CodeBlock.of("spring()"));
}
@Test
void toMethodReferenceWithStaticMethod() {
GeneratedMethod generatedMethod = create(method -> method.addModifiers(Modifier.STATIC));
MethodReference methodReference = generatedMethod.toMethodReference();
assertThat(methodReference).isNotNull();
ClassName anotherDeclaringClass = ClassName.get("com.example", "Another");
assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), anotherDeclaringClass))
.isEqualTo(CodeBlock.of("com.example.Test.spring()"));
}
private GeneratedMethod create(Consumer<MethodSpec.Builder> method) {
return new GeneratedMethod(TEST_CLASS_NAME, NAME, method);
}
}

View File

@ -23,6 +23,7 @@ import java.util.function.Function;
import org.junit.jupiter.api.Test;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.MethodSpec;
import static org.assertj.core.api.Assertions.assertThat;
@ -32,38 +33,49 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
* Tests for {@link GeneratedMethods}.
*
* @author Phillip Webb
* @author Stephane Nicoll
*/
class GeneratedMethodsTests {
private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test");
private static final Consumer<MethodSpec.Builder> methodSpecCustomizer = method -> {};
private final GeneratedMethods methods = new GeneratedMethods(MethodName::toString);
private final GeneratedMethods methods = new GeneratedMethods(TEST_CLASS_NAME, MethodName::toString);
@Test
void createWhenClassNameIsNullThrowsException() {
assertThatIllegalArgumentException().isThrownBy(() ->
new GeneratedMethods(null, MethodName::toString))
.withMessage("'className' must not be null");
}
@Test
void createWhenMethodNameGeneratorIsNullThrowsException() {
assertThatIllegalArgumentException().isThrownBy(() -> new GeneratedMethods(null))
assertThatIllegalArgumentException().isThrownBy(() ->
new GeneratedMethods(TEST_CLASS_NAME, null))
.withMessage("'methodNameGenerator' must not be null");
}
@Test
void createWithExistingGeneratorUsesGenerator() {
Function<MethodName, String> generator = name -> "__" + name.toString();
GeneratedMethods methods = new GeneratedMethods(generator);
GeneratedMethods methods = new GeneratedMethods(TEST_CLASS_NAME, generator);
assertThat(methods.add("test", methodSpecCustomizer).getName()).hasToString("__test");
}
@Test
void addWithStringNameWhenSuggestedMethodIsNullThrowsException() {
assertThatIllegalArgumentException().isThrownBy(() ->
this.methods.add((String) null, methodSpecCustomizer))
.withMessage("'suggestedName' must not be null");
this.methods.add((String) null, methodSpecCustomizer))
.withMessage("'suggestedName' must not be null");
}
@Test
void addWithStringNameWhenMethodIsNullThrowsException() {
assertThatIllegalArgumentException().isThrownBy(() ->
this.methods.add("test", null))
.withMessage("'method' must not be null");
this.methods.add("test", null))
.withMessage("'method' must not be null");
}
@Test
@ -71,7 +83,7 @@ class GeneratedMethodsTests {
this.methods.add("springBeans", methodSpecCustomizer);
this.methods.add("springContext", methodSpecCustomizer);
assertThat(this.methods.stream().map(GeneratedMethod::getName).map(Object::toString))
.containsExactly("springBeans", "springContext");
.containsExactly("springBeans", "springContext");
}
@Test

View File

@ -1,226 +0,0 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.aot.generate;
import org.junit.jupiter.api.Test;
import org.springframework.javapoet.ClassName;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
/**
* Tests for {@link MethodReference}.
*
* @author Phillip Webb
*/
class MethodReferenceTests {
private static final String EXPECTED_STATIC = "org.springframework.aot.generate.MethodReferenceTests::someMethod";
private static final String EXPECTED_ANONYMOUS_INSTANCE = "<instance>::someMethod";
private static final String EXPECTED_DECLARED_INSTANCE = "<org.springframework.aot.generate.MethodReferenceTests>::someMethod";
@Test
void ofWithStringWhenMethodNameIsNullThrowsException() {
String methodName = null;
assertThatIllegalArgumentException()
.isThrownBy(() -> MethodReference.of(methodName))
.withMessage("'methodName' must not be empty");
}
@Test
void ofWithStringCreatesMethodReference() {
String methodName = "someMethod";
MethodReference reference = MethodReference.of(methodName);
assertThat(reference).hasToString(EXPECTED_ANONYMOUS_INSTANCE);
}
@Test
void ofWithClassAndStringWhenDeclaringClassIsNullThrowsException() {
Class<?> declaringClass = null;
String methodName = "someMethod";
assertThatIllegalArgumentException()
.isThrownBy(() -> MethodReference.of(declaringClass, methodName))
.withMessage("'declaringClass' must not be null");
}
@Test
void ofWithClassAndStringWhenMethodNameIsNullThrowsException() {
Class<?> declaringClass = MethodReferenceTests.class;
String methodName = null;
assertThatIllegalArgumentException()
.isThrownBy(() -> MethodReference.of(declaringClass, methodName))
.withMessage("'methodName' must not be empty");
}
@Test
void ofWithClassAndStringCreatesMethodReference() {
Class<?> declaringClass = MethodReferenceTests.class;
String methodName = "someMethod";
MethodReference reference = MethodReference.of(declaringClass, methodName);
assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE);
}
@Test
void ofWithClassNameAndStringWhenDeclaringClassIsNullThrowsException() {
ClassName declaringClass = null;
String methodName = "someMethod";
assertThatIllegalArgumentException()
.isThrownBy(() -> MethodReference.of(declaringClass, methodName))
.withMessage("'declaringClass' must not be null");
}
@Test
void ofWithClassNameAndStringWhenMethodNameIsNullThrowsException() {
ClassName declaringClass = ClassName.get(MethodReferenceTests.class);
String methodName = null;
assertThatIllegalArgumentException()
.isThrownBy(() -> MethodReference.of(declaringClass, methodName))
.withMessage("'methodName' must not be empty");
}
@Test
void ofWithClassNameAndStringCreateMethodReference() {
ClassName declaringClass = ClassName.get(MethodReferenceTests.class);
String methodName = "someMethod";
MethodReference reference = MethodReference.of(declaringClass, methodName);
assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE);
}
@Test
void ofStaticWithClassAndStringWhenDeclaringClassIsNullThrowsException() {
Class<?> declaringClass = null;
String methodName = "someMethod";
assertThatIllegalArgumentException()
.isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName))
.withMessage("'declaringClass' must not be null");
}
@Test
void ofStaticWithClassAndStringWhenMethodNameIsEmptyThrowsException() {
Class<?> declaringClass = MethodReferenceTests.class;
String methodName = null;
assertThatIllegalArgumentException()
.isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName))
.withMessage("'methodName' must not be empty");
}
@Test
void ofStaticWithClassAndStringCreatesMethodReference() {
Class<?> declaringClass = MethodReferenceTests.class;
String methodName = "someMethod";
MethodReference reference = MethodReference.ofStatic(declaringClass, methodName);
assertThat(reference).hasToString(EXPECTED_STATIC);
}
@Test
void ofStaticWithClassNameAndGeneratedMethodNameWhenDeclaringClassIsNullThrowsException() {
ClassName declaringClass = null;
String methodName = "someMethod";
assertThatIllegalArgumentException()
.isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName))
.withMessage("'declaringClass' must not be null");
}
@Test
void ofStaticWithClassNameAndGeneratedMethodNameWhenMethodNameIsEmptyThrowsException() {
ClassName declaringClass = ClassName.get(MethodReferenceTests.class);
String methodName = null;
assertThatIllegalArgumentException()
.isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName))
.withMessage("'methodName' must not be empty");
}
@Test
void ofStaticWithClassNameAndGeneratedMethodNameCreatesMethodReference() {
ClassName declaringClass = ClassName.get(MethodReferenceTests.class);
String methodName = "someMethod";
MethodReference reference = MethodReference.ofStatic(declaringClass, methodName);
assertThat(reference).hasToString(EXPECTED_STATIC);
}
@Test
void toCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNull() {
MethodReference reference = MethodReference.of("someMethod");
assertThat(reference.toCodeBlock(null)).hasToString("this::someMethod");
}
@Test
void toCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNotNull() {
MethodReference reference = MethodReference.of("someMethod");
assertThat(reference.toCodeBlock("myInstance"))
.hasToString("myInstance::someMethod");
}
@Test
void toCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNull() {
MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class,
"someMethod");
assertThat(reference.toCodeBlock(null)).hasToString(EXPECTED_STATIC);
}
@Test
void toCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNotNullThrowsException() {
MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class,
"someMethod");
assertThatIllegalArgumentException()
.isThrownBy(() -> reference.toCodeBlock("myInstance")).withMessage(
"'instanceVariable' must be null for static method references");
}
@Test
void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNull() {
MethodReference reference = MethodReference.of("someMethod");
assertThat(reference.toInvokeCodeBlock()).hasToString("someMethod()");
}
@Test
void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNullAndHasDecalredClass() {
MethodReference reference = MethodReference.of(MethodReferenceTests.class,
"someMethod");
assertThat(reference.toInvokeCodeBlock()).hasToString(
"new org.springframework.aot.generate.MethodReferenceTests().someMethod()");
}
@Test
void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNotNull() {
MethodReference reference = MethodReference.of("someMethod");
assertThat(reference.toInvokeCodeBlock("myInstance"))
.hasToString("myInstance.someMethod()");
}
@Test
void toInvokeCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNull() {
MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class,
"someMethod");
assertThat(reference.toInvokeCodeBlock()).hasToString(
"org.springframework.aot.generate.MethodReferenceTests.someMethod()");
}
@Test
void toInvokeCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNotNullThrowsException() {
MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class,
"someMethod");
assertThatIllegalArgumentException()
.isThrownBy(() -> reference.toInvokeCodeBlock("myInstance")).withMessage(
"'instanceVariable' must be null for static method references");
}
}

View File

@ -99,7 +99,7 @@ class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistr
List.class, toCodeBlock(persistenceManagedTypes.getManagedPackages()));
method.addStatement("return $T.of($L, $L)", beanType, "managedClassNames", "managedPackages");
});
return CodeBlock.of("() -> $T.$L()", beanRegistrationCode.getClassName(), generatedMethod.getName());
return generatedMethod.toMethodReference().toCodeBlock();
}
private CodeBlock toCodeBlock(List<String> values) {

View File

@ -43,7 +43,6 @@ import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.PropertyValues;
@ -797,8 +796,7 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar
method.returns(this.target);
method.addCode(generateMethodCode(generationContext.getRuntimeHints(), generatedClass.getMethods()));
});
beanRegistrationCode.addInstancePostProcessor(MethodReference
.ofStatic(generatedClass.getName(), generatedMethod.getName()));
beanRegistrationCode.addInstancePostProcessor(generatedMethod.toMethodReference());
}
private CodeBlock generateMethodCode(RuntimeHints hints, GeneratedMethods generatedMethods) {