Refine Coroutines annotated controller support

This commit refines Coroutines annotated controller support
by considering Kotlin Unit as Java void and using the right
ReactiveAdapter to support all use cases, including suspending
functions that return Flow (usual when using APIs like WebClient).

It also fixes RSocket fire and forget handling and adds related tests
for that use case.

Closes gh-24057
Closes gh-23866
This commit is contained in:
Sébastien Deleuze 2019-11-22 12:11:07 +01:00
parent 21b2fc1f01
commit 712eac2915
5 changed files with 93 additions and 21 deletions

View File

@ -26,6 +26,7 @@ import kotlinx.coroutines.reactive.awaitFirstOrNull
import kotlinx.coroutines.reactor.asFlux
import kotlinx.coroutines.reactor.mono
import org.reactivestreams.Publisher
import reactor.core.publisher.Mono
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
@ -51,28 +52,29 @@ internal fun <T: Any> monoToDeferred(source: Mono<T>) =
GlobalScope.async(Dispatchers.Unconfined) { source.awaitFirstOrNull() }
/**
* Invoke a suspending function converting it to [Mono] or [reactor.core.publisher.Flux]
* if necessary.
* Return {@code true} if the method is a suspending function.
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
internal fun isSuspendingFunction(method: Method) = method.kotlinFunction!!.isSuspend
/**
* Invoke a suspending function and converts it to [Mono] or [reactor.core.publisher.Flux].
*
* @author Sebastien Deleuze
* @since 5.2
*/
@Suppress("UNCHECKED_CAST")
internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Any? {
internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Publisher<*> {
val function = method.kotlinFunction!!
return if (function.isSuspend) {
val mono = mono(Dispatchers.Unconfined) {
function.callSuspend(bean, *args.sliceArray(0..(args.size-2)))
.let { if (it == Unit) null else it }
}.onErrorMap(InvocationTargetException::class.java) { it.targetException }
if (function.returnType.classifier == Flow::class) {
mono.flatMapMany { (it as Flow<Any>).asFlux() }
}
else {
mono
}
val mono = mono(Dispatchers.Unconfined) {
function.callSuspend(bean, *args.sliceArray(0..(args.size-2))).let { if (it == Unit) null else it }
}.onErrorMap(InvocationTargetException::class.java) { it.targetException }
return if (function.returnType.classifier == Flow::class) {
mono.flatMapMany { (it as Flow<Any>).asFlux() }
}
else {
function.call(bean, *args)
mono
}
}

View File

@ -30,6 +30,7 @@ import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import kotlin.Unit;
import kotlin.reflect.KFunction;
import kotlin.reflect.KParameter;
import kotlin.reflect.jvm.ReflectJvmMapping;
@ -929,6 +930,9 @@ public class MethodParameter {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
if (function != null && function.isSuspend()) {
Type paramType = ReflectJvmMapping.getJavaType(function.getReturnType());
if (paramType == Unit.class) {
paramType = void.class;
}
return ResolvableType.forType(paramType).resolve(method.getReturnType());
}
}

View File

@ -127,10 +127,13 @@ public class InvocableHandlerMethod extends HandlerMethod {
public Mono<Object> invoke(Message<?> message, Object... providedArgs) {
return getMethodArgumentValues(message, providedArgs).flatMap(args -> {
Object value;
boolean isSuspendingFunction = false;
try {
Method method = getBridgedMethod();
ReflectionUtils.makeAccessible(method);
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())
&& CoroutinesUtils.isSuspendingFunction(method)) {
isSuspendingFunction = true;
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
}
else {
@ -151,7 +154,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
}
MethodParameter returnType = getReturnType();
ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(returnType.getParameterType());
Class<?> reactiveType = (isSuspendingFunction ? value.getClass() : returnType.getParameterType());
ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(reactiveType);
return (isAsyncVoidReturnType(returnType, adapter) ?
Mono.from(adapter.toPublisher(value)) : Mono.justOrEmpty(value));
});

View File

@ -39,6 +39,7 @@ import org.springframework.messaging.handler.annotation.MessageMapping
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler
import org.springframework.stereotype.Controller
import reactor.core.publisher.Flux
import reactor.core.publisher.ReplayProcessor
import reactor.test.StepVerifier
import java.time.Duration
@ -50,6 +51,34 @@ import java.time.Duration
*/
class RSocketClientToServerCoroutinesIntegrationTests {
@Test
fun fireAndForget() {
Flux.range(1, 3)
.concatMap { requester.route("receive").data("Hello $it").send() }
.blockLast()
StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads)
.expectNext("Hello 1")
.expectNext("Hello 2")
.expectNext("Hello 3")
.thenAwait(Duration.ofMillis(50))
.thenCancel()
.verify(Duration.ofSeconds(5))
}
@Test
fun fireAndForgetAsync() {
Flux.range(1, 3)
.concatMap { i: Int -> requester.route("receive-async").data("Hello $i").send() }
.blockLast()
StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads)
.expectNext("Hello 1")
.expectNext("Hello 2")
.expectNext("Hello 3")
.thenAwait(Duration.ofMillis(50))
.thenCancel()
.verify(Duration.ofSeconds(5))
}
@Test
fun echoAsync() {
val result = Flux.range(1, 3).concatMap { i -> requester.route("echo-async").data("Hello " + i!!).retrieveMono(String::class.java) }
@ -70,6 +99,16 @@ class RSocketClientToServerCoroutinesIntegrationTests {
.verify(Duration.ofSeconds(5))
}
@Test
fun echoStreamAsync() {
val result = requester.route("echo-stream-async").data("Hello").retrieveFlux(String::class.java)
StepVerifier.create(result)
.expectNext("Hello 0").expectNextCount(6).expectNext("Hello 7")
.thenCancel()
.verify(Duration.ofSeconds(5))
}
@Test
fun echoChannel() {
val result = requester.route("echo-channel")
@ -106,6 +145,19 @@ class RSocketClientToServerCoroutinesIntegrationTests {
@Controller
class ServerController {
val fireForgetPayloads = ReplayProcessor.create<String>()
@MessageMapping("receive")
fun receive(payload: String) {
fireForgetPayloads.onNext(payload)
}
@MessageMapping("receive-async")
suspend fun receiveAsync(payload: String) {
delay(10)
fireForgetPayloads.onNext(payload)
}
@MessageMapping("echo-async")
suspend fun echoAsync(payload: String): String {
delay(10)
@ -123,6 +175,18 @@ class RSocketClientToServerCoroutinesIntegrationTests {
}
}
@MessageMapping("echo-stream-async")
suspend fun echoStreamAsync(payload: String): Flow<String> {
delay(10)
var i = 0
return flow {
while(true) {
delay(10)
emit("$payload ${i++}")
}
}
}
@MessageMapping("echo-channel")
fun echoChannel(payloads: Flow<String>) = payloads.map {
delay(10)
@ -185,8 +249,6 @@ class RSocketClientToServerCoroutinesIntegrationTests {
private lateinit var server: CloseableChannel
private val interceptor = FireAndForgetCountingInterceptor()
private lateinit var requester: RSocketRequester
@ -196,7 +258,6 @@ class RSocketClientToServerCoroutinesIntegrationTests {
context = AnnotationConfigApplicationContext(ServerConfig::class.java)
server = RSocketFactory.receive()
.addResponderPlugin(interceptor)
.frameDecoder(PayloadDecoder.ZERO_COPY)
.acceptor(context.getBean(RSocketMessageHandler::class.java).responder())
.transport(TcpServerTransport.create("localhost", 7000))

View File

@ -139,7 +139,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
try {
ReflectionUtils.makeAccessible(getBridgedMethod());
Method method = getBridgedMethod();
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())
&& CoroutinesUtils.isSuspendingFunction(method)) {
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
}
else {