Move Kotlin value class unboxing to InvocableHandlerMethod

Before this commit, in Spring Framework 6.2, Kotlin value class
unboxing was done at CoroutinesUtils level, which is a good fit
for InvocableHandlerMethod use case, but not for other ones like
AopUtils.

This commit moves such unboxing to InvocableHandlerMethod in
order to keep the HTTP response body support while fixing other
regressions.

Closes gh-33943
This commit is contained in:
Sébastien Deleuze 2024-11-27 16:27:37 +01:00
parent ea3bd7ae0c
commit 1aede291bb
6 changed files with 320 additions and 62 deletions

View File

@ -44,7 +44,6 @@ import kotlinx.coroutines.reactor.ReactorFlowKt;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
@ -109,7 +108,7 @@ public abstract class CoroutinesUtils {
* @throws IllegalArgumentException if {@code method} is not a suspending function
* @since 6.0
*/
@SuppressWarnings({"deprecation", "DataFlowIssue", "NullAway"})
@SuppressWarnings({"DataFlowIssue", "NullAway"})
public static Publisher<?> invokeSuspendingFunction(
CoroutineContext context, Method method, @Nullable Object target, @Nullable Object... args) {
@ -146,7 +145,7 @@ public abstract class CoroutinesUtils {
}
return KCallables.callSuspendBy(function, argMap, continuation);
})
.handle(CoroutinesUtils::handleResult)
.filter(result -> result != Unit.INSTANCE)
.onErrorMap(InvocationTargetException.class, InvocationTargetException::getTargetException);
KType returnType = function.getReturnType();
@ -166,22 +165,4 @@ public abstract class CoroutinesUtils {
return ReactorFlowKt.asFlux(((Flow<?>) flow));
}
private static void handleResult(Object result, SynchronousSink<Object> sink) {
if (result == Unit.INSTANCE) {
sink.complete();
}
else if (KotlinDetector.isInlineClass(result.getClass())) {
try {
sink.next(result.getClass().getDeclaredMethod("unbox-impl").invoke(result));
sink.complete();
}
catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException ex) {
sink.error(ex);
}
}
else {
sink.next(result);
sink.complete();
}
}
}

View File

@ -192,7 +192,7 @@ class CoroutinesUtilsTests {
@Test
fun invokeSuspendingFunctionWithValueClassParameter() {
val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithValueClass") }
val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithValueClassParameter") }
val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, "foo", null) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingle()).isEqualTo("foo")
@ -204,7 +204,16 @@ class CoroutinesUtilsTests {
val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithValueClassReturnValue") }
val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, null) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingle()).isEqualTo("foo")
Assertions.assertThat(mono.awaitSingle()).isEqualTo(ValueClass("foo"))
}
}
@Test
fun invokeSuspendingFunctionWithResultOfUnitReturnValue() {
val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithResultOfUnitReturnValue") }
val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, null) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingle()).isEqualTo(Result.success(Unit))
}
}
@ -314,7 +323,7 @@ class CoroutinesUtilsTests {
return null
}
suspend fun suspendingFunctionWithValueClass(value: ValueClass): String {
suspend fun suspendingFunctionWithValueClassParameter(value: ValueClass): String {
delay(1)
return value.value
}
@ -324,6 +333,11 @@ class CoroutinesUtilsTests {
return ValueClass("foo")
}
suspend fun suspendingFunctionWithResultOfUnitReturnValue(): Result<Unit> {
delay(1)
return Result.success(Unit)
}
suspend fun suspendingFunctionWithValueClassWithInit(value: ValueClassWithInit): String {
delay(1)
return value.value

View File

@ -30,6 +30,8 @@ import kotlin.reflect.KType;
import kotlin.reflect.full.KClasses;
import kotlin.reflect.jvm.KCallablesJvm;
import kotlin.reflect.jvm.ReflectJvmMapping;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;
import org.springframework.context.MessageSource;
import org.springframework.core.CoroutinesUtils;
@ -288,7 +290,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
* @since 6.0
*/
protected Object invokeSuspendingFunction(Method method, Object target, Object[] args) {
return CoroutinesUtils.invokeSuspendingFunction(method, target, args);
Object result = CoroutinesUtils.invokeSuspendingFunction(method, target, args);
return (result instanceof Mono<?> mono ? mono.handle(KotlinDelegate::handleResult) : result);
}
@ -298,7 +301,7 @@ public class InvocableHandlerMethod extends HandlerMethod {
private static class KotlinDelegate {
@Nullable
@SuppressWarnings({"deprecation", "DataFlowIssue"})
@SuppressWarnings("DataFlowIssue")
public static Object invokeFunction(Method method, Object target, Object[] args) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
// For property accessors
@ -333,10 +336,33 @@ public class InvocableHandlerMethod extends HandlerMethod {
}
Object result = function.callBy(argMap);
if (result != null && KotlinDetector.isInlineClass(result.getClass())) {
return result.getClass().getDeclaredMethod("unbox-impl").invoke(result);
result = unbox(result);
}
return (result == Unit.INSTANCE ? null : result);
}
private static void handleResult(Object result, SynchronousSink<Object> sink) {
if (KotlinDetector.isInlineClass(result.getClass())) {
try {
Object unboxed = unbox(result);
if (unboxed != Unit.INSTANCE) {
sink.next(unboxed);
}
sink.complete();
}
catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException ex) {
sink.error(ex);
}
}
else {
sink.next(result);
sink.complete();
}
}
private static Object unbox(Object result) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
return result.getClass().getDeclaredMethod("unbox-impl").invoke(result);
}
}
}

View File

@ -16,14 +16,18 @@
package org.springframework.web.method.support
import kotlinx.coroutines.delay
import org.assertj.core.api.Assertions
import org.junit.jupiter.api.Test
import org.springframework.core.MethodParameter
import org.springframework.util.ReflectionUtils
import org.springframework.web.bind.support.WebDataBinderFactory
import org.springframework.web.context.request.NativeWebRequest
import org.springframework.web.context.request.ServletWebRequest
import org.springframework.web.testfixture.method.ResolvableMethod
import org.springframework.web.testfixture.servlet.MockHttpServletRequest
import org.springframework.web.testfixture.servlet.MockHttpServletResponse
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
import java.lang.reflect.Method
import kotlin.reflect.jvm.javaGetter
import kotlin.reflect.jvm.javaMethod
@ -33,6 +37,7 @@ import kotlin.reflect.jvm.javaMethod
*
* @author Sebastien Deleuze
*/
@Suppress("UNCHECKED_CAST")
class InvocableHandlerMethodKotlinTests {
private val request: NativeWebRequest = ServletWebRequest(MockHttpServletRequest(), MockHttpServletResponse())
@ -110,6 +115,12 @@ class InvocableHandlerMethodKotlinTests {
Assertions.assertThat(value).isEqualTo("foo")
}
@Test
fun resultOfUnitReturnValue() {
val value = getInvocable(ValueClassHandler::resultOfUnitReturnValue.javaMethod!!).invokeForRequest(request, null)
Assertions.assertThat(value).isNull()
}
@Test
fun valueClassDefaultValue() {
composite.addResolver(StubArgumentResolver(Double::class.java))
@ -138,6 +149,60 @@ class InvocableHandlerMethodKotlinTests {
Assertions.assertThat(value).isEqualTo('a')
}
@Test
fun suspendingValueClass() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
composite.addResolver(StubArgumentResolver(Long::class.java, 1L))
val value = getInvocable(SuspendingValueClassHandler::longValueClass.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<Long>).expectNext(1L).verifyComplete()
}
@Test
fun suspendingValueClassReturnValue() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
val value = getInvocable(SuspendingValueClassHandler::valueClassReturnValue.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<String>).expectNext("foo").verifyComplete()
}
@Test
fun suspendingResultOfUnitReturnValue() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
val value = getInvocable(SuspendingValueClassHandler::resultOfUnitReturnValue.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<Unit>).verifyComplete()
}
@Test
fun suspendingValueClassDefaultValue() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
composite.addResolver(StubArgumentResolver(Double::class.java))
val value = getInvocable(SuspendingValueClassHandler::doubleValueClass.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<Double>).expectNext(3.1).verifyComplete()
}
@Test
fun suspendingValueClassWithInit() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
composite.addResolver(StubArgumentResolver(String::class.java, ""))
val value = getInvocable(SuspendingValueClassHandler::valueClassWithInit.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<String>).verifyError(IllegalArgumentException::class.java)
}
@Test
fun suspendingValueClassWithNullable() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
composite.addResolver(StubArgumentResolver(LongValueClass::class.java, null))
val value = getInvocable(SuspendingValueClassHandler::valueClassWithNullable.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<Long>).verifyComplete()
}
@Test
fun suspendingValueClassWithPrivateConstructor() {
composite.addResolver(ContinuationHandlerMethodArgumentResolver())
composite.addResolver(StubArgumentResolver(Char::class.java, 'a'))
val value = getInvocable(SuspendingValueClassHandler::valueClassWithPrivateConstructor.javaMethod!!).invokeForRequest(request, null)
StepVerifier.create(value as Mono<Char>).expectNext('a').verifyComplete()
}
@Test
fun propertyAccessor() {
val value = getInvocable(PropertyAccessorHandler::prop.javaGetter!!).invokeForRequest(request, null)
@ -206,23 +271,58 @@ class InvocableHandlerMethodKotlinTests {
private class ValueClassHandler {
fun valueClassReturnValue() =
StringValueClass("foo")
fun valueClassReturnValue() = StringValueClass("foo")
fun longValueClass(limit: LongValueClass) =
limit.value
fun resultOfUnitReturnValue() = Result.success(Unit)
fun doubleValueClass(limit: DoubleValueClass = DoubleValueClass(3.1)) =
limit.value
fun longValueClass(limit: LongValueClass) = limit.value
fun valueClassWithInit(valueClass: ValueClassWithInit) =
valueClass
fun doubleValueClass(limit: DoubleValueClass = DoubleValueClass(3.1)) = limit.value
fun valueClassWithNullable(limit: LongValueClass?) =
limit?.value
fun valueClassWithInit(valueClass: ValueClassWithInit) = valueClass
fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor) =
limit.value
fun valueClassWithNullable(limit: LongValueClass?) = limit?.value
fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor) = limit.value
}
private class SuspendingValueClassHandler {
suspend fun valueClassReturnValue(): StringValueClass {
delay(1)
return StringValueClass("foo")
}
suspend fun resultOfUnitReturnValue(): Result<Unit> {
delay(1)
return Result.success(Unit)
}
suspend fun longValueClass(limit: LongValueClass): Long {
delay(1)
return limit.value
}
suspend fun doubleValueClass(limit: DoubleValueClass = DoubleValueClass(3.1)): Double {
delay(1)
return limit.value
}
suspend fun valueClassWithInit(valueClass: ValueClassWithInit): ValueClassWithInit {
delay(1)
return valueClass
}
suspend fun valueClassWithNullable(limit: LongValueClass?): Long? {
delay(1)
return limit?.value
}
suspend fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor): Char {
delay(1)
return limit.value
}
}
private class PropertyAccessorHandler {
@ -282,4 +382,19 @@ class InvocableHandlerMethodKotlinTests {
class CustomException(message: String) : Throwable(message)
// Avoid adding a spring-webmvc dependency
class ContinuationHandlerMethodArgumentResolver : HandlerMethodArgumentResolver {
override fun supportsParameter(parameter: MethodParameter) =
"kotlin.coroutines.Continuation" == parameter.getParameterType().getName()
override fun resolveArgument(
parameter: MethodParameter,
mavContainer: ModelAndViewContainer?,
webRequest: NativeWebRequest,
binderFactory: WebDataBinderFactory?
) = null
}
}

View File

@ -36,6 +36,7 @@ import kotlin.reflect.full.KClasses;
import kotlin.reflect.jvm.KCallablesJvm;
import kotlin.reflect.jvm.ReflectJvmMapping;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;
import reactor.core.scheduler.Scheduler;
import org.springframework.core.CoroutinesUtils;
@ -323,18 +324,15 @@ public class InvocableHandlerMethod extends HandlerMethod {
private static final String COROUTINE_CONTEXT_ATTRIBUTE = "org.springframework.web.server.CoWebFilter.context";
@Nullable
@SuppressWarnings({"deprecation", "DataFlowIssue"})
@SuppressWarnings("DataFlowIssue")
public static Object invokeFunction(Method method, Object target, Object[] args, boolean isSuspendingFunction,
ServerWebExchange exchange) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
if (isSuspendingFunction) {
Object coroutineContext = exchange.getAttribute(COROUTINE_CONTEXT_ATTRIBUTE);
if (coroutineContext == null) {
return CoroutinesUtils.invokeSuspendingFunction(method, target, args);
}
else {
return CoroutinesUtils.invokeSuspendingFunction((CoroutineContext) coroutineContext, method, target, args);
}
Object result = (coroutineContext == null ? CoroutinesUtils.invokeSuspendingFunction(method, target, args) :
CoroutinesUtils.invokeSuspendingFunction((CoroutineContext) coroutineContext, method, target, args));
return (result instanceof Mono<?> mono ? mono.handle(KotlinDelegate::handleResult) : result);
}
else {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
@ -370,11 +368,35 @@ public class InvocableHandlerMethod extends HandlerMethod {
}
Object result = function.callBy(argMap);
if (result != null && KotlinDetector.isInlineClass(result.getClass())) {
return result.getClass().getDeclaredMethod("unbox-impl").invoke(result);
result = unbox(result);
}
return (result == Unit.INSTANCE ? null : result);
}
}
private static void handleResult(Object result, SynchronousSink<Object> sink) {
if (KotlinDetector.isInlineClass(result.getClass())) {
try {
Object unboxed = unbox(result);
if (unboxed != Unit.INSTANCE) {
sink.next(unboxed);
}
sink.complete();
}
catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException ex) {
sink.error(ex);
}
}
else {
sink.next(result);
sink.complete();
}
}
private static Object unbox(Object result) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
return result.getClass().getDeclaredMethod("unbox-impl").invoke(result);
}
}
}

View File

@ -208,10 +208,17 @@ class InvocableHandlerMethodKotlinTests {
@Test
fun valueClassReturnValue() {
val method = ValueClassController::valueClassReturnValue.javaMethod!!
val result = invoke(ValueClassController(), method,)
val result = invoke(ValueClassController(), method)
assertHandlerResultValue(result, "foo")
}
@Test
fun resultOfUnitReturnValue() {
val method = ValueClassController::resultOfUnitReturnValue.javaMethod!!
val result = invoke(ValueClassController(), method)
assertHandlerResultValue(result, null)
}
@Test
fun valueClassWithDefaultValue() {
this.resolvers.add(stubResolver(null, Double::class.java))
@ -244,6 +251,60 @@ class InvocableHandlerMethodKotlinTests {
assertHandlerResultValue(result, "1")
}
@Test
fun suspendingValueClass() {
this.resolvers.add(stubResolver(1L, Long::class.java))
val method = SuspendingValueClassController::valueClass.javaMethod!!
val result = invoke(SuspendingValueClassController(), method,1L)
assertHandlerResultValue(result, "1")
}
@Test
fun suspendingValueClassReturnValue() {
val method = SuspendingValueClassController::valueClassReturnValue.javaMethod!!
val result = invoke(SuspendingValueClassController(), method)
assertHandlerResultValue(result, "foo")
}
@Test
fun suspendingResultOfUnitReturnValue() {
val method = SuspendingValueClassController::resultOfUnitReturnValue.javaMethod!!
val result = invoke(SuspendingValueClassController(), method)
assertComplete(result)
}
@Test
fun suspendingValueClassWithDefaultValue() {
this.resolvers.add(stubResolver(null, Double::class.java))
val method = SuspendingValueClassController::valueClassWithDefault.javaMethod!!
val result = invoke(SuspendingValueClassController(), method)
assertHandlerResultValue(result, "3.1")
}
@Test
fun suspendingValueClassWithInit() {
this.resolvers.add(stubResolver("", String::class.java))
val method = SuspendingValueClassController::valueClassWithInit.javaMethod!!
val result = invoke(SuspendingValueClassController(), method)
assertExceptionThrown(result, IllegalArgumentException::class)
}
@Test
fun suspendingValueClassWithNullable() {
this.resolvers.add(stubResolver(null, LongValueClass::class.java))
val method = SuspendingValueClassController::valueClassWithNullable.javaMethod!!
val result = invoke(SuspendingValueClassController(), method, null)
assertHandlerResultValue(result, "null")
}
@Test
fun suspendingValueClassWithPrivateConstructor() {
this.resolvers.add(stubResolver(1L, Long::class.java))
val method = SuspendingValueClassController::valueClassWithPrivateConstructor.javaMethod!!
val result = invoke(SuspendingValueClassController(), method, 1L)
assertHandlerResultValue(result, "1")
}
@Test
fun propertyAccessor() {
this.resolvers.add(stubResolver(null, String::class.java))
@ -313,9 +374,14 @@ class InvocableHandlerMethodKotlinTests {
}
private fun assertExceptionThrown(mono: Mono<HandlerResult>, exceptionClass: KClass<out Throwable>) {
StepVerifier.create(mono).verifyError(exceptionClass.java)
StepVerifier.create(mono.flatMap { t -> t.returnValue as Mono<*> }).verifyError(exceptionClass.java)
}
private fun assertComplete(mono: Mono<HandlerResult>) {
StepVerifier.create(mono.flatMap { t -> t.returnValue as Mono<*> }).verifyComplete()
}
class CoroutinesController {
suspend fun singleArg(q: String?): String {
@ -380,23 +446,57 @@ class InvocableHandlerMethodKotlinTests {
class ValueClassController {
fun valueClass(limit: LongValueClass) =
"${limit.value}"
fun valueClass(limit: LongValueClass) = "${limit.value}"
fun valueClassReturnValue() =
StringValueClass("foo")
fun valueClassReturnValue() = StringValueClass("foo")
fun valueClassWithDefault(limit: DoubleValueClass = DoubleValueClass(3.1)) =
"${limit.value}"
fun resultOfUnitReturnValue() = Result.success(Unit)
fun valueClassWithInit(valueClass: ValueClassWithInit) =
valueClass
fun valueClassWithDefault(limit: DoubleValueClass = DoubleValueClass(3.1)) = "${limit.value}"
fun valueClassWithNullable(limit: LongValueClass?) =
"${limit?.value}"
fun valueClassWithInit(valueClass: ValueClassWithInit) = valueClass
fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor) =
"${limit.value}"
fun valueClassWithNullable(limit: LongValueClass?) = "${limit?.value}"
fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor) = "${limit.value}"
}
class SuspendingValueClassController {
suspend fun valueClass(limit: LongValueClass): String {
delay(1)
return "${limit.value}"
}
suspend fun valueClassReturnValue(): StringValueClass {
delay(1)
return StringValueClass("foo")
}
suspend fun resultOfUnitReturnValue(): Result<Unit> {
delay(1)
return Result.success(Unit)
}
suspend fun valueClassWithDefault(limit: DoubleValueClass = DoubleValueClass(3.1)): String {
delay(1)
return "${limit.value}"
}
suspend fun valueClassWithInit(valueClass: ValueClassWithInit): ValueClassWithInit {
delay(1)
return valueClass
}
suspend fun valueClassWithNullable(limit: LongValueClass?): String {
delay(1)
return "${limit?.value}"
}
suspend fun valueClassWithPrivateConstructor(limit: ValueClassWithPrivateConstructor): String {
delay(1)
return "${limit.value}"
}
}
class PropertyAccessorController {