Add AOT support for Kotlin constructors with optional parameters

This commit leverages Kotlin reflection to instantiate classes
with constructors using optional parameters in the code
generated AOT.

Closes gh-29820
This commit is contained in:
Sébastien Deleuze 2023-07-07 17:20:58 +02:00
parent 20dd66cd5a
commit a03a14797f
5 changed files with 223 additions and 8 deletions

View File

@ -27,6 +27,7 @@ import java.util.stream.Collectors;
import org.springframework.aot.hint.ExecutableMode; import org.springframework.aot.hint.ExecutableMode;
import org.springframework.beans.BeanInstantiationException; import org.springframework.beans.BeanInstantiationException;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.BeansException; import org.springframework.beans.BeansException;
import org.springframework.beans.TypeConverter; import org.springframework.beans.TypeConverter;
import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactory;
@ -343,8 +344,7 @@ public final class BeanInstanceSupplier<T> extends AutowiredElementResolver impl
Object enclosingInstance = createInstance(declaringClass.getEnclosingClass()); Object enclosingInstance = createInstance(declaringClass.getEnclosingClass());
args = ObjectUtils.addObjectToArray(args, enclosingInstance, 0); args = ObjectUtils.addObjectToArray(args, enclosingInstance, 0);
} }
ReflectionUtils.makeAccessible(constructor); return BeanUtils.instantiateClass(constructor, args);
return constructor.newInstance(args);
} }
private Object instantiate(ConfigurableBeanFactory beanFactory, Method method, Object[] args) throws Exception { private Object instantiate(ConfigurableBeanFactory beanFactory, Method method, Object[] args) throws Exception {

View File

@ -24,6 +24,11 @@ import java.lang.reflect.Modifier;
import java.util.Arrays; import java.util.Arrays;
import java.util.function.Consumer; import java.util.function.Consumer;
import kotlin.jvm.JvmClassMappingKt;
import kotlin.reflect.KClass;
import kotlin.reflect.KFunction;
import kotlin.reflect.KParameter;
import org.springframework.aot.generate.AccessControl; import org.springframework.aot.generate.AccessControl;
import org.springframework.aot.generate.AccessControl.Visibility; import org.springframework.aot.generate.AccessControl.Visibility;
import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethod;
@ -31,8 +36,11 @@ import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.hint.ExecutableMode; import org.springframework.aot.hint.ExecutableMode;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.ReflectionHints;
import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.InstanceSupplier;
import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.KotlinDetector;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
@ -56,6 +64,7 @@ import org.springframework.util.function.ThrowingSupplier;
* @author Phillip Webb * @author Phillip Webb
* @author Stephane Nicoll * @author Stephane Nicoll
* @author Juergen Hoeller * @author Juergen Hoeller
* @author Sebastien Deleuze
* @since 6.0 * @since 6.0
*/ */
class InstanceSupplierCodeGenerator { class InstanceSupplierCodeGenerator {
@ -108,11 +117,16 @@ class InstanceSupplierCodeGenerator {
boolean dependsOnBean = ClassUtils.isInnerClass(declaringClass); boolean dependsOnBean = ClassUtils.isInnerClass(declaringClass);
Visibility accessVisibility = getAccessVisibility(registeredBean, constructor); Visibility accessVisibility = getAccessVisibility(registeredBean, constructor);
if (accessVisibility != Visibility.PRIVATE) { if (KotlinDetector.isKotlinReflectPresent() && KotlinDelegate.hasConstructorWithOptionalParameter(beanClass)) {
return generateCodeForInaccessibleConstructor(beanName, beanClass, constructor,
dependsOnBean, hints -> hints.registerType(beanClass, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS));
}
else if (accessVisibility != Visibility.PRIVATE) {
return generateCodeForAccessibleConstructor(beanName, beanClass, constructor, return generateCodeForAccessibleConstructor(beanName, beanClass, constructor,
dependsOnBean, declaringClass); dependsOnBean, declaringClass);
} }
return generateCodeForInaccessibleConstructor(beanName, beanClass, constructor, dependsOnBean); return generateCodeForInaccessibleConstructor(beanName, beanClass, constructor, dependsOnBean,
hints -> hints.registerConstructor(constructor, ExecutableMode.INVOKE));
} }
private CodeBlock generateCodeForAccessibleConstructor(String beanName, Class<?> beanClass, private CodeBlock generateCodeForAccessibleConstructor(String beanName, Class<?> beanClass,
@ -137,11 +151,10 @@ class InstanceSupplierCodeGenerator {
return generateReturnStatement(generatedMethod); return generateReturnStatement(generatedMethod);
} }
private CodeBlock generateCodeForInaccessibleConstructor(String beanName, private CodeBlock generateCodeForInaccessibleConstructor(String beanName, Class<?> beanClass,
Class<?> beanClass, Constructor<?> constructor, boolean dependsOnBean) { Constructor<?> constructor, boolean dependsOnBean, Consumer<ReflectionHints> hints) {
this.generationContext.getRuntimeHints().reflection() hints.accept(this.generationContext.getRuntimeHints().reflection());
.registerConstructor(constructor, ExecutableMode.INVOKE);
GeneratedMethod generatedMethod = generateGetInstanceSupplierMethod(method -> { GeneratedMethod generatedMethod = generateGetInstanceSupplierMethod(method -> {
method.addJavadoc("Get the bean instance supplier for '$L'.", beanName); method.addJavadoc("Get the bean instance supplier for '$L'.", beanName);
@ -337,4 +350,25 @@ class InstanceSupplierCodeGenerator {
.anyMatch(Exception.class::isAssignableFrom); .anyMatch(Exception.class::isAssignableFrom);
} }
/**
* Inner class to avoid a hard dependency on Kotlin at runtime.
*/
private static class KotlinDelegate {
public static boolean hasConstructorWithOptionalParameter(Class<?> beanClass) {
if (KotlinDetector.isKotlinType(beanClass)) {
KClass<?> kClass = JvmClassMappingKt.getKotlinClass(beanClass);
for (KFunction<?> constructor : kClass.getConstructors()) {
for (KParameter parameter : constructor.getParameters()) {
if (parameter.isOptional()) {
return true;
}
}
}
}
return false;
}
}
} }

View File

@ -0,0 +1,143 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.beans.factory.aot
import org.assertj.core.api.Assertions
import org.assertj.core.api.ThrowingConsumer
import org.junit.jupiter.api.Test
import org.springframework.aot.hint.*
import org.springframework.aot.test.generate.TestGenerationContext
import org.springframework.beans.factory.config.BeanDefinition
import org.springframework.beans.factory.support.DefaultListableBeanFactory
import org.springframework.beans.factory.support.InstanceSupplier
import org.springframework.beans.factory.support.RegisteredBean
import org.springframework.beans.factory.support.RootBeanDefinition
import org.springframework.beans.testfixture.beans.KotlinTestBean
import org.springframework.beans.testfixture.beans.KotlinTestBeanWithOptionalParameter
import org.springframework.beans.testfixture.beans.factory.aot.DeferredTypeBuilder
import org.springframework.core.test.tools.Compiled
import org.springframework.core.test.tools.TestCompiler
import org.springframework.javapoet.MethodSpec
import org.springframework.javapoet.ParameterizedTypeName
import org.springframework.javapoet.TypeSpec
import java.util.function.BiConsumer
import java.util.function.Supplier
import javax.lang.model.element.Modifier
/**
* Kotlin tests for [InstanceSupplierCodeGenerator].
*
* @author Sebastien Deleuze
*/
class InstanceSupplierCodeGeneratorKotlinTests {
private val generationContext = TestGenerationContext()
@Test
fun generateWhenHasDefaultConstructor() {
val beanDefinition: BeanDefinition = RootBeanDefinition(KotlinTestBean::class.java)
val beanFactory = DefaultListableBeanFactory()
compile(beanFactory, beanDefinition) { instanceSupplier, compiled ->
val bean = getBean<KotlinTestBean>(beanFactory, beanDefinition, instanceSupplier)
Assertions.assertThat(bean).isInstanceOf(KotlinTestBean::class.java)
Assertions.assertThat(compiled.sourceFile).contains("InstanceSupplier.using(KotlinTestBean::new)")
}
Assertions.assertThat(getReflectionHints().getTypeHint(KotlinTestBean::class.java))
.satisfies(hasConstructorWithMode(ExecutableMode.INTROSPECT))
}
@Test
fun generateWhenConstructorHasOptionalParameter() {
val beanDefinition: BeanDefinition = RootBeanDefinition(KotlinTestBeanWithOptionalParameter::class.java)
val beanFactory = DefaultListableBeanFactory()
compile(beanFactory, beanDefinition) { instanceSupplier, compiled ->
val bean: KotlinTestBeanWithOptionalParameter = getBean(beanFactory, beanDefinition, instanceSupplier)
Assertions.assertThat(bean).isInstanceOf(KotlinTestBeanWithOptionalParameter::class.java)
Assertions.assertThat(compiled.sourceFile)
.contains("return BeanInstanceSupplier.<KotlinTestBeanWithOptionalParameter>forConstructor();")
}
Assertions.assertThat<TypeHint>(getReflectionHints().getTypeHint(KotlinTestBeanWithOptionalParameter::class.java))
.satisfies(hasMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS))
}
private fun getReflectionHints(): ReflectionHints {
return generationContext.runtimeHints.reflection()
}
private fun hasConstructorWithMode(mode: ExecutableMode): ThrowingConsumer<TypeHint> {
return ThrowingConsumer {
Assertions.assertThat(it.constructors()).anySatisfy(hasMode(mode))
}
}
private fun hasMemberCategory(category: MemberCategory): ThrowingConsumer<TypeHint> {
return ThrowingConsumer {
Assertions.assertThat(it.memberCategories).contains(category)
}
}
private fun hasMode(mode: ExecutableMode): ThrowingConsumer<ExecutableHint> {
return ThrowingConsumer {
Assertions.assertThat(it.mode).isEqualTo(mode)
}
}
@Suppress("UNCHECKED_CAST")
private fun <T> getBean(beanFactory: DefaultListableBeanFactory, beanDefinition: BeanDefinition,
instanceSupplier: InstanceSupplier<*>): T {
(beanDefinition as RootBeanDefinition).instanceSupplier = instanceSupplier
beanFactory.registerBeanDefinition("testBean", beanDefinition)
return beanFactory.getBean("testBean") as T
}
private fun compile(beanFactory: DefaultListableBeanFactory, beanDefinition: BeanDefinition,
result: BiConsumer<InstanceSupplier<*>, Compiled>) {
val freshBeanFactory = DefaultListableBeanFactory(beanFactory)
freshBeanFactory.registerBeanDefinition("testBean", beanDefinition)
val registeredBean = RegisteredBean.of(freshBeanFactory, "testBean")
val typeBuilder = DeferredTypeBuilder()
val generateClass = generationContext.generatedClasses.addForFeature("TestCode", typeBuilder)
val generator = InstanceSupplierCodeGenerator(
generationContext, generateClass.name,
generateClass.methods, false
)
val constructorOrFactoryMethod = registeredBean.resolveConstructorOrFactoryMethod()
Assertions.assertThat(constructorOrFactoryMethod).isNotNull()
val generatedCode = generator.generateCode(registeredBean, constructorOrFactoryMethod)
typeBuilder.set { type: TypeSpec.Builder ->
type.addModifiers(Modifier.PUBLIC)
type.addSuperinterface(
ParameterizedTypeName.get(
Supplier::class.java,
InstanceSupplier::class.java
)
)
type.addMethod(
MethodSpec.methodBuilder("get")
.addModifiers(Modifier.PUBLIC)
.returns(InstanceSupplier::class.java)
.addStatement("return \$L", generatedCode).build()
)
}
generationContext.writeGeneratedContent()
TestCompiler.forSystem().with(generationContext).compile {
result.accept(it.getInstance(Supplier::class.java).get() as InstanceSupplier<*>, it)
}
}
}

View File

@ -0,0 +1,19 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.beans.testfixture.beans
class KotlinTestBean

View File

@ -0,0 +1,19 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.beans.testfixture.beans
class KotlinTestBeanWithOptionalParameter(private val other: KotlinTestBean = KotlinTestBean())