Support Kotlin extensions in web handlers

This commit restores support for Kotlin extensions in
web handlers, and adds support for invoking reflectively
suspending extension functions, as well as the other
features supported as of Spring Framework 6.1 like
value classes and default value for parameters.

Closes gh-31876
This commit is contained in:
Sébastien Deleuze 2023-12-21 12:15:26 +01:00
parent 5f8a031c22
commit 85cb6cc5fb
6 changed files with 130 additions and 25 deletions

View File

@ -117,8 +117,11 @@ public abstract class CoroutinesUtils {
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE -> {
case INSTANCE:
argMap.put(parameter, target);
break;
case VALUE:
case EXTENSION_RECEIVER:
if (!parameter.isOptional() || args[index] != null) {
if (parameter.getType().getClassifier() instanceof KClass<?> kClass && kClass.isValue()) {
Class<?> javaClass = JvmClassMappingKt.getJavaClass(kClass);
@ -131,7 +134,8 @@ public abstract class CoroutinesUtils {
}
}
index++;
}
break;
}
}
return KCallables.callSuspendBy(function, argMap, continuation);

View File

@ -154,6 +154,26 @@ class CoroutinesUtilsTests {
}
}
@Test
fun invokeSuspendingFunctionWithExtension() {
val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithExtension",
CustomException::class.java, Continuation::class.java)
val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, CustomException("foo")) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingleOrNull()).isEqualTo("foo")
}
}
@Test
fun invokeSuspendingFunctionWithExtensionAndParameter() {
val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithExtensionAndParameter",
CustomException::class.java, Int::class.java, Continuation::class.java)
val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, CustomException("foo"), 20) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingleOrNull()).isEqualTo("foo-20")
}
}
suspend fun suspendingFunction(value: String): String {
delay(1)
return value
@ -186,7 +206,19 @@ class CoroutinesUtilsTests {
return value.value
}
suspend fun CustomException.suspendingFunctionWithExtension(): String {
delay(1)
return "${this.message}"
}
suspend fun CustomException.suspendingFunctionWithExtensionAndParameter(limit: Int): String {
delay(1)
return "${this.message}-$limit"
}
@JvmInline
value class ValueClass(val value: String)
class CustomException(message: String) : Throwable(message)
}

View File

@ -318,8 +318,11 @@ public class InvocableHandlerMethod extends HandlerMethod {
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE -> {
case INSTANCE:
argMap.put(parameter, target);
break;
case VALUE:
case EXTENSION_RECEIVER:
if (!parameter.isOptional() || args[index] != null) {
if (parameter.getType().getClassifier() instanceof KClass<?> kClass && kClass.isValue()) {
Class<?> javaClass = JvmClassMappingKt.getJavaClass(kClass);
@ -332,7 +335,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
}
}
index++;
}
break;
}
}
Object result = function.callBy(argMap);

View File

@ -104,6 +104,22 @@ class InvocableHandlerMethodKotlinTests {
Assertions.assertThat(value).isEqualTo("foo")
}
@Test
fun extension() {
composite.addResolver(StubArgumentResolver(CustomException::class.java, CustomException("foo")))
val value = getInvocable(ExtensionHandler::class.java, CustomException::class.java).invokeForRequest(request, null)
Assertions.assertThat(value).isEqualTo("foo")
}
@Test
fun extensionWithParameter() {
composite.addResolver(StubArgumentResolver(CustomException::class.java, CustomException("foo")))
composite.addResolver(StubArgumentResolver(Int::class.java, 20))
val value = getInvocable(ExtensionHandler::class.java, CustomException::class.java, Int::class.java)
.invokeForRequest(request, null)
Assertions.assertThat(value).isEqualTo("foo-20")
}
private fun getInvocable(clazz: Class<*>, vararg argTypes: Class<*>): InvocableHandlerMethod {
val method = ResolvableMethod.on(clazz).argTypes(*argTypes).resolveMethod()
val handlerMethod = InvocableHandlerMethod(clazz.constructors.first().newInstance(), method)
@ -150,10 +166,23 @@ class InvocableHandlerMethodKotlinTests {
get() = "foo"
}
private class ExtensionHandler {
fun CustomException.handle(): String {
return "${this.message}"
}
fun CustomException.handleWithParameter(limit: Int): String {
return "${this.message}-$limit"
}
}
@JvmInline
value class LongValueClass(val value: Long)
@JvmInline
value class DoubleValueClass(val value: Double)
class CustomException(message: String) : Throwable(message)
}

View File

@ -329,8 +329,11 @@ public class InvocableHandlerMethod extends HandlerMethod {
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE -> {
case INSTANCE:
argMap.put(parameter, target);
break;
case VALUE:
case EXTENSION_RECEIVER:
if (!parameter.isOptional() || args[index] != null) {
if (parameter.getType().getClassifier() instanceof KClass<?> kClass && kClass.isValue()) {
Class<?> javaClass = JvmClassMappingKt.getJavaClass(kClass);
@ -343,7 +346,7 @@ public class InvocableHandlerMethod extends HandlerMethod {
}
}
index++;
}
break;
}
}
Object result = function.callBy(argMap);

View File

@ -21,6 +21,7 @@ import io.mockk.mockk
import kotlinx.coroutines.delay
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.springframework.core.MethodParameter
import org.springframework.core.ReactiveAdapterRegistry
import org.springframework.http.HttpStatus
import org.springframework.http.server.reactive.ServerHttpResponse
@ -34,6 +35,7 @@ import org.springframework.web.reactive.result.method.InvocableHandlerMethod
import org.springframework.web.reactive.result.method.annotation.ContinuationHandlerMethodArgumentResolver
import org.springframework.web.reactive.result.method.annotation.RequestParamMethodArgumentResolver
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.get
import org.springframework.web.testfixture.method.ResolvableMethod
import org.springframework.web.testfixture.server.MockServerWebExchange
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
@ -55,7 +57,7 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun resolveNoArg() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = CoroutinesController::singleArg.javaMethod!!
val result = invoke(CoroutinesController(), method, null)
assertHandlerResultValue(result, "success:null")
@ -116,7 +118,7 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun defaultValue() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = DefaultValueController::handle.javaMethod!!
val result = invoke(DefaultValueController(), method)
assertHandlerResultValue(result, "default")
@ -124,7 +126,7 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun defaultValueOverridden() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = DefaultValueController::handle.javaMethod!!
exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("value", "override"))
val result = invoke(DefaultValueController(), method)
@ -133,7 +135,7 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun defaultValues() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, Int::class.java))
val method = DefaultValueController::handleMultiple.javaMethod!!
val result = invoke(DefaultValueController(), method)
assertHandlerResultValue(result, "10-20")
@ -141,7 +143,7 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun defaultValuesOverridden() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, Int::class.java))
val method = DefaultValueController::handleMultiple.javaMethod!!
exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("limit2", "40"))
val result = invoke(DefaultValueController(), method)
@ -150,7 +152,7 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun suspendingDefaultValue() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = DefaultValueController::handleSuspending.javaMethod!!
val result = invoke(DefaultValueController(), method)
assertHandlerResultValue(result, "default")
@ -158,7 +160,7 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun suspendingDefaultValueOverridden() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = DefaultValueController::handleSuspending.javaMethod!!
exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("value", "override"))
val result = invoke(DefaultValueController(), method)
@ -181,7 +183,7 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun valueClass() {
this.resolvers.add(stubResolver(1L))
this.resolvers.add(stubResolver(1L, Long::class.java))
val method = ValueClassController::valueClass.javaMethod!!
val result = invoke(ValueClassController(), method,1L)
assertHandlerResultValue(result, "1")
@ -189,7 +191,7 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun valueClassDefaultValue() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, Double::class.java))
val method = ValueClassController::valueClassWithDefault.javaMethod!!
val result = invoke(ValueClassController(), method)
assertHandlerResultValue(result, "3.1")
@ -197,12 +199,31 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun propertyAccessor() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = PropertyAccessorController::prop.getter.javaMethod!!
val result = invoke(PropertyAccessorController(), method)
assertHandlerResultValue(result, "foo")
}
@Test
fun extension() {
this.resolvers.add(stubResolver(CustomException("foo")))
val method = ResolvableMethod.on(ExtensionHandler::class.java).argTypes(CustomException::class.java).resolveMethod()
val result = invoke(ExtensionHandler(), method)
assertHandlerResultValue(result, "foo")
}
@Test
fun extensionWithParameter() {
this.resolvers.add(stubResolver(CustomException("foo")))
this.resolvers.add(stubResolver(20, Int::class.java))
val method = ResolvableMethod.on(ExtensionHandler::class.java)
.argTypes(CustomException::class.java, Int::class.javaPrimitiveType)
.resolveMethod()
val result = invoke(ExtensionHandler(), method)
assertHandlerResultValue(result, "foo-20")
}
private fun invokeForResult(handler: Any, method: Method, vararg providedArgs: Any): HandlerResult? {
return invoke(handler, method, *providedArgs).block(Duration.ofSeconds(5))
@ -214,14 +235,13 @@ class InvocableHandlerMethodKotlinTests {
return invocable.invoke(this.exchange, BindingContext(), *providedArgs)
}
private fun stubResolver(stubValue: Any?): HandlerMethodArgumentResolver {
return stubResolver(Mono.justOrEmpty(stubValue))
}
private fun stubResolver(stubValue: Any): HandlerMethodArgumentResolver =
stubResolver(stubValue, stubValue::class.java)
private fun stubResolver(stubValue: Mono<Any>): HandlerMethodArgumentResolver {
private fun stubResolver(stubValue: Any?, stubClass: Class<*>): HandlerMethodArgumentResolver {
val resolver = mockk<HandlerMethodArgumentResolver>()
every { resolver.supportsParameter(any()) } returns true
every { resolver.resolveArgument(any(), any(), any()) } returns stubValue
every { resolver.supportsParameter(any()) } answers { (it.invocation.args[0] as MethodParameter).getParameterType() == stubClass }
every { resolver.resolveArgument(any(), any(), any()) } returns Mono.justOrEmpty(stubValue)
return resolver
}
@ -309,9 +329,22 @@ class InvocableHandlerMethodKotlinTests {
get() = "foo"
}
class ExtensionHandler {
fun CustomException.handle(): String {
return "${this.message}"
}
fun CustomException.handleWithParameter(limit: Int): String {
return "${this.message}-$limit"
}
}
@JvmInline
value class LongValueClass(val value: Long)
@JvmInline
value class DoubleValueClass(val value: Double)
class CustomException(message: String) : Throwable(message)
}