diff --git a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java index d36d59376f..fa84da293d 100644 --- a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java +++ b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java @@ -25,12 +25,13 @@ import kotlin.Unit; import kotlin.coroutines.CoroutineContext; import kotlin.jvm.JvmClassMappingKt; import kotlin.reflect.KClass; -import kotlin.reflect.KClassifier; import kotlin.reflect.KFunction; import kotlin.reflect.KParameter; import kotlin.reflect.KType; import kotlin.reflect.full.KCallables; import kotlin.reflect.full.KClasses; +import kotlin.reflect.full.KClassifiers; +import kotlin.reflect.full.KTypes; import kotlin.reflect.jvm.KCallablesJvm; import kotlin.reflect.jvm.KTypesJvm; import kotlin.reflect.jvm.ReflectJvmMapping; @@ -58,6 +59,12 @@ import org.springframework.util.CollectionUtils; */ public abstract class CoroutinesUtils { + private static final KType flowType = KClassifiers.getStarProjectedType(JvmClassMappingKt.getKotlinClass(Flow.class)); + + private static final KType monoType = KClassifiers.getStarProjectedType(JvmClassMappingKt.getKotlinClass(Mono.class)); + + private static final KType publisherType = KClassifiers.getStarProjectedType(JvmClassMappingKt.getKotlinClass(Publisher.class)); + /** * Convert a {@link Deferred} instance to a {@link Mono}. */ @@ -137,18 +144,15 @@ public abstract class CoroutinesUtils { .filter(result -> result != Unit.INSTANCE) .onErrorMap(InvocationTargetException.class, InvocationTargetException::getTargetException); - KClassifier returnType = function.getReturnType().getClassifier(); - if (returnType != null) { - if (returnType.equals(JvmClassMappingKt.getKotlinClass(Flow.class))) { - return mono.flatMapMany(CoroutinesUtils::asFlux); - } - else if (returnType.equals(JvmClassMappingKt.getKotlinClass(Mono.class))) { - return mono.flatMap(o -> ((Mono)o)); - } - else if (returnType instanceof KClass kClass && - Publisher.class.isAssignableFrom(JvmClassMappingKt.getJavaClass(kClass))) { - return mono.flatMapMany(o -> ((Publisher)o)); - } + KType returnType = function.getReturnType(); + if (KTypes.isSubtypeOf(returnType, flowType)) { + return mono.flatMapMany(CoroutinesUtils::asFlux); + } + else if (KTypes.isSubtypeOf(returnType, monoType)) { + return mono.flatMap(o -> ((Mono)o)); + } + else if (KTypes.isSubtypeOf(returnType, publisherType)) { + return mono.flatMapMany(o -> ((Publisher)o)); } return mono; } diff --git a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt index 2091fbe0dd..ad0dd07f54 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt @@ -97,6 +97,29 @@ class CoroutinesUtilsTests { Assertions.assertThatIllegalArgumentException().isThrownBy { CoroutinesUtils.invokeSuspendingFunction(method, this, "foo") } } + @Test + fun invokeSuspendingFunctionWithMono() { + val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithMono", Continuation::class.java) + val publisher = CoroutinesUtils.invokeSuspendingFunction(method, this) + Assertions.assertThat(publisher).isInstanceOf(Mono::class.java) + StepVerifier.create(publisher) + .expectNext("foo") + .expectComplete() + .verify() + } + + @Test + fun invokeSuspendingFunctionWithFlux() { + val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithFlux", Continuation::class.java) + val publisher = CoroutinesUtils.invokeSuspendingFunction(method, this) + Assertions.assertThat(publisher).isInstanceOf(Flux::class.java) + StepVerifier.create(publisher) + .expectNext("foo") + .expectNext("bar") + .expectComplete() + .verify() + } + @Test fun invokeSuspendingFunctionWithFlow() { val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithFlow", Continuation::class.java) @@ -213,6 +236,16 @@ class CoroutinesUtilsTests { return value } + suspend fun suspendingFunctionWithMono(): Mono { + delay(1) + return Mono.just("foo") + } + + suspend fun suspendingFunctionWithFlux(): Flux { + delay(1) + return Flux.just("foo", "bar") + } + suspend fun suspendingFunctionWithFlow(): Flow { delay(1) return flowOf("foo", "bar")