From 85cb6cc5fb337ec14505b0903834e73da6b45579 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Thu, 21 Dec 2023 12:15:26 +0100 Subject: [PATCH] Support Kotlin extensions in web handlers This commit restores support for Kotlin extensions in web handlers, and adds support for invoking reflectively suspending extension functions, as well as the other features supported as of Spring Framework 6.1 like value classes and default value for parameters. Closes gh-31876 --- .../springframework/core/CoroutinesUtils.java | 10 ++- .../core/CoroutinesUtilsTests.kt | 32 +++++++++ .../support/InvocableHandlerMethod.java | 10 ++- .../InvocableHandlerMethodKotlinTests.kt | 29 +++++++++ .../result/method/InvocableHandlerMethod.java | 9 ++- .../InvocableHandlerMethodKotlinTests.kt | 65 ++++++++++++++----- 6 files changed, 130 insertions(+), 25 deletions(-) diff --git a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java index e6de675a306..8c78ceac9be 100644 --- a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java +++ b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java @@ -117,8 +117,11 @@ public abstract class CoroutinesUtils { int index = 0; for (KParameter parameter : function.getParameters()) { switch (parameter.getKind()) { - case INSTANCE -> argMap.put(parameter, target); - case VALUE -> { + case INSTANCE: + argMap.put(parameter, target); + break; + case VALUE: + case EXTENSION_RECEIVER: if (!parameter.isOptional() || args[index] != null) { if (parameter.getType().getClassifier() instanceof KClass kClass && kClass.isValue()) { Class javaClass = JvmClassMappingKt.getJavaClass(kClass); @@ -131,7 +134,8 @@ public abstract class CoroutinesUtils { } } index++; - } + break; + } } return KCallables.callSuspendBy(function, argMap, continuation); diff --git a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt index b5fad73d9cc..fdce5caf8d8 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt @@ -154,6 +154,26 @@ class CoroutinesUtilsTests { } } + @Test + fun invokeSuspendingFunctionWithExtension() { + val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithExtension", + CustomException::class.java, Continuation::class.java) + val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, CustomException("foo")) as Mono + runBlocking { + Assertions.assertThat(mono.awaitSingleOrNull()).isEqualTo("foo") + } + } + + @Test + fun invokeSuspendingFunctionWithExtensionAndParameter() { + val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithExtensionAndParameter", + CustomException::class.java, Int::class.java, Continuation::class.java) + val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, CustomException("foo"), 20) as Mono + runBlocking { + Assertions.assertThat(mono.awaitSingleOrNull()).isEqualTo("foo-20") + } + } + suspend fun suspendingFunction(value: String): String { delay(1) return value @@ -186,7 +206,19 @@ class CoroutinesUtilsTests { return value.value } + suspend fun CustomException.suspendingFunctionWithExtension(): String { + delay(1) + return "${this.message}" + } + + suspend fun CustomException.suspendingFunctionWithExtensionAndParameter(limit: Int): String { + delay(1) + return "${this.message}-$limit" + } + @JvmInline value class ValueClass(val value: String) + class CustomException(message: String) : Throwable(message) + } diff --git a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java index ed160be6f55..303373de389 100644 --- a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java +++ b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java @@ -318,8 +318,11 @@ public class InvocableHandlerMethod extends HandlerMethod { int index = 0; for (KParameter parameter : function.getParameters()) { switch (parameter.getKind()) { - case INSTANCE -> argMap.put(parameter, target); - case VALUE -> { + case INSTANCE: + argMap.put(parameter, target); + break; + case VALUE: + case EXTENSION_RECEIVER: if (!parameter.isOptional() || args[index] != null) { if (parameter.getType().getClassifier() instanceof KClass kClass && kClass.isValue()) { Class javaClass = JvmClassMappingKt.getJavaClass(kClass); @@ -332,7 +335,8 @@ public class InvocableHandlerMethod extends HandlerMethod { } } index++; - } + break; + } } Object result = function.callBy(argMap); diff --git a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt index 7ee5d15c88d..2ba7f249d42 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt @@ -104,6 +104,22 @@ class InvocableHandlerMethodKotlinTests { Assertions.assertThat(value).isEqualTo("foo") } + @Test + fun extension() { + composite.addResolver(StubArgumentResolver(CustomException::class.java, CustomException("foo"))) + val value = getInvocable(ExtensionHandler::class.java, CustomException::class.java).invokeForRequest(request, null) + Assertions.assertThat(value).isEqualTo("foo") + } + + @Test + fun extensionWithParameter() { + composite.addResolver(StubArgumentResolver(CustomException::class.java, CustomException("foo"))) + composite.addResolver(StubArgumentResolver(Int::class.java, 20)) + val value = getInvocable(ExtensionHandler::class.java, CustomException::class.java, Int::class.java) + .invokeForRequest(request, null) + Assertions.assertThat(value).isEqualTo("foo-20") + } + private fun getInvocable(clazz: Class<*>, vararg argTypes: Class<*>): InvocableHandlerMethod { val method = ResolvableMethod.on(clazz).argTypes(*argTypes).resolveMethod() val handlerMethod = InvocableHandlerMethod(clazz.constructors.first().newInstance(), method) @@ -150,10 +166,23 @@ class InvocableHandlerMethodKotlinTests { get() = "foo" } + private class ExtensionHandler { + + fun CustomException.handle(): String { + return "${this.message}" + } + + fun CustomException.handleWithParameter(limit: Int): String { + return "${this.message}-$limit" + } + } + @JvmInline value class LongValueClass(val value: Long) @JvmInline value class DoubleValueClass(val value: Double) + class CustomException(message: String) : Throwable(message) + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index 490e897091d..0d4c2f47dd7 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -329,8 +329,11 @@ public class InvocableHandlerMethod extends HandlerMethod { int index = 0; for (KParameter parameter : function.getParameters()) { switch (parameter.getKind()) { - case INSTANCE -> argMap.put(parameter, target); - case VALUE -> { + case INSTANCE: + argMap.put(parameter, target); + break; + case VALUE: + case EXTENSION_RECEIVER: if (!parameter.isOptional() || args[index] != null) { if (parameter.getType().getClassifier() instanceof KClass kClass && kClass.isValue()) { Class javaClass = JvmClassMappingKt.getJavaClass(kClass); @@ -343,7 +346,7 @@ public class InvocableHandlerMethod extends HandlerMethod { } } index++; - } + break; } } Object result = function.callBy(argMap); diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt index 57ef009b935..61daeb51400 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt @@ -21,6 +21,7 @@ import io.mockk.mockk import kotlinx.coroutines.delay import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test +import org.springframework.core.MethodParameter import org.springframework.core.ReactiveAdapterRegistry import org.springframework.http.HttpStatus import org.springframework.http.server.reactive.ServerHttpResponse @@ -34,6 +35,7 @@ import org.springframework.web.reactive.result.method.InvocableHandlerMethod import org.springframework.web.reactive.result.method.annotation.ContinuationHandlerMethodArgumentResolver import org.springframework.web.reactive.result.method.annotation.RequestParamMethodArgumentResolver import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.get +import org.springframework.web.testfixture.method.ResolvableMethod import org.springframework.web.testfixture.server.MockServerWebExchange import reactor.core.publisher.Mono import reactor.test.StepVerifier @@ -55,7 +57,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun resolveNoArg() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = CoroutinesController::singleArg.javaMethod!! val result = invoke(CoroutinesController(), method, null) assertHandlerResultValue(result, "success:null") @@ -116,7 +118,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun defaultValue() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = DefaultValueController::handle.javaMethod!! val result = invoke(DefaultValueController(), method) assertHandlerResultValue(result, "default") @@ -124,7 +126,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun defaultValueOverridden() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = DefaultValueController::handle.javaMethod!! exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("value", "override")) val result = invoke(DefaultValueController(), method) @@ -133,7 +135,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun defaultValues() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, Int::class.java)) val method = DefaultValueController::handleMultiple.javaMethod!! val result = invoke(DefaultValueController(), method) assertHandlerResultValue(result, "10-20") @@ -141,7 +143,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun defaultValuesOverridden() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, Int::class.java)) val method = DefaultValueController::handleMultiple.javaMethod!! exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("limit2", "40")) val result = invoke(DefaultValueController(), method) @@ -150,7 +152,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun suspendingDefaultValue() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = DefaultValueController::handleSuspending.javaMethod!! val result = invoke(DefaultValueController(), method) assertHandlerResultValue(result, "default") @@ -158,7 +160,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun suspendingDefaultValueOverridden() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = DefaultValueController::handleSuspending.javaMethod!! exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("value", "override")) val result = invoke(DefaultValueController(), method) @@ -181,7 +183,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun valueClass() { - this.resolvers.add(stubResolver(1L)) + this.resolvers.add(stubResolver(1L, Long::class.java)) val method = ValueClassController::valueClass.javaMethod!! val result = invoke(ValueClassController(), method,1L) assertHandlerResultValue(result, "1") @@ -189,7 +191,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun valueClassDefaultValue() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, Double::class.java)) val method = ValueClassController::valueClassWithDefault.javaMethod!! val result = invoke(ValueClassController(), method) assertHandlerResultValue(result, "3.1") @@ -197,12 +199,31 @@ class InvocableHandlerMethodKotlinTests { @Test fun propertyAccessor() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = PropertyAccessorController::prop.getter.javaMethod!! val result = invoke(PropertyAccessorController(), method) assertHandlerResultValue(result, "foo") } + @Test + fun extension() { + this.resolvers.add(stubResolver(CustomException("foo"))) + val method = ResolvableMethod.on(ExtensionHandler::class.java).argTypes(CustomException::class.java).resolveMethod() + val result = invoke(ExtensionHandler(), method) + assertHandlerResultValue(result, "foo") + } + + @Test + fun extensionWithParameter() { + this.resolvers.add(stubResolver(CustomException("foo"))) + this.resolvers.add(stubResolver(20, Int::class.java)) + val method = ResolvableMethod.on(ExtensionHandler::class.java) + .argTypes(CustomException::class.java, Int::class.javaPrimitiveType) + .resolveMethod() + val result = invoke(ExtensionHandler(), method) + assertHandlerResultValue(result, "foo-20") + } + private fun invokeForResult(handler: Any, method: Method, vararg providedArgs: Any): HandlerResult? { return invoke(handler, method, *providedArgs).block(Duration.ofSeconds(5)) @@ -214,14 +235,13 @@ class InvocableHandlerMethodKotlinTests { return invocable.invoke(this.exchange, BindingContext(), *providedArgs) } - private fun stubResolver(stubValue: Any?): HandlerMethodArgumentResolver { - return stubResolver(Mono.justOrEmpty(stubValue)) - } + private fun stubResolver(stubValue: Any): HandlerMethodArgumentResolver = + stubResolver(stubValue, stubValue::class.java) - private fun stubResolver(stubValue: Mono): HandlerMethodArgumentResolver { + private fun stubResolver(stubValue: Any?, stubClass: Class<*>): HandlerMethodArgumentResolver { val resolver = mockk() - every { resolver.supportsParameter(any()) } returns true - every { resolver.resolveArgument(any(), any(), any()) } returns stubValue + every { resolver.supportsParameter(any()) } answers { (it.invocation.args[0] as MethodParameter).getParameterType() == stubClass } + every { resolver.resolveArgument(any(), any(), any()) } returns Mono.justOrEmpty(stubValue) return resolver } @@ -309,9 +329,22 @@ class InvocableHandlerMethodKotlinTests { get() = "foo" } + class ExtensionHandler { + + fun CustomException.handle(): String { + return "${this.message}" + } + + fun CustomException.handleWithParameter(limit: Int): String { + return "${this.message}-$limit" + } + } + @JvmInline value class LongValueClass(val value: Long) @JvmInline value class DoubleValueClass(val value: Double) + + class CustomException(message: String) : Throwable(message) } \ No newline at end of file