Refine Coroutines annotated controller support

This commit refines Coroutines annotated controller support
by considering Kotlin Unit as Java void and using the right
ReactiveAdapter to support all use cases, including suspending
functions that return Flow (usual when using APIs like WebClient).

It also fixes RSocket fire and forget handling and adds related tests
for that use case.

Closes gh-24057
Closes gh-23866
This commit is contained in:
Sébastien Deleuze 2019-11-22 12:11:07 +01:00
parent 21b2fc1f01
commit 712eac2915
5 changed files with 93 additions and 21 deletions

View File

@ -26,6 +26,7 @@ import kotlinx.coroutines.reactive.awaitFirstOrNull
import kotlinx.coroutines.reactor.asFlux import kotlinx.coroutines.reactor.asFlux
import kotlinx.coroutines.reactor.mono import kotlinx.coroutines.reactor.mono
import org.reactivestreams.Publisher
import reactor.core.publisher.Mono import reactor.core.publisher.Mono
import java.lang.reflect.InvocationTargetException import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method import java.lang.reflect.Method
@ -51,28 +52,29 @@ internal fun <T: Any> monoToDeferred(source: Mono<T>) =
GlobalScope.async(Dispatchers.Unconfined) { source.awaitFirstOrNull() } GlobalScope.async(Dispatchers.Unconfined) { source.awaitFirstOrNull() }
/** /**
* Invoke a suspending function converting it to [Mono] or [reactor.core.publisher.Flux] * Return {@code true} if the method is a suspending function.
* if necessary. *
* @author Sebastien Deleuze
* @since 5.2.2
*/
internal fun isSuspendingFunction(method: Method) = method.kotlinFunction!!.isSuspend
/**
* Invoke a suspending function and converts it to [Mono] or [reactor.core.publisher.Flux].
* *
* @author Sebastien Deleuze * @author Sebastien Deleuze
* @since 5.2 * @since 5.2
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Any? { internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Publisher<*> {
val function = method.kotlinFunction!! val function = method.kotlinFunction!!
return if (function.isSuspend) { val mono = mono(Dispatchers.Unconfined) {
val mono = mono(Dispatchers.Unconfined) { function.callSuspend(bean, *args.sliceArray(0..(args.size-2))).let { if (it == Unit) null else it }
function.callSuspend(bean, *args.sliceArray(0..(args.size-2))) }.onErrorMap(InvocationTargetException::class.java) { it.targetException }
.let { if (it == Unit) null else it } return if (function.returnType.classifier == Flow::class) {
}.onErrorMap(InvocationTargetException::class.java) { it.targetException } mono.flatMapMany { (it as Flow<Any>).asFlux() }
if (function.returnType.classifier == Flow::class) {
mono.flatMapMany { (it as Flow<Any>).asFlux() }
}
else {
mono
}
} }
else { else {
function.call(bean, *args) mono
} }
} }

View File

@ -30,6 +30,7 @@ import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.function.Predicate; import java.util.function.Predicate;
import kotlin.Unit;
import kotlin.reflect.KFunction; import kotlin.reflect.KFunction;
import kotlin.reflect.KParameter; import kotlin.reflect.KParameter;
import kotlin.reflect.jvm.ReflectJvmMapping; import kotlin.reflect.jvm.ReflectJvmMapping;
@ -929,6 +930,9 @@ public class MethodParameter {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method); KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
if (function != null && function.isSuspend()) { if (function != null && function.isSuspend()) {
Type paramType = ReflectJvmMapping.getJavaType(function.getReturnType()); Type paramType = ReflectJvmMapping.getJavaType(function.getReturnType());
if (paramType == Unit.class) {
paramType = void.class;
}
return ResolvableType.forType(paramType).resolve(method.getReturnType()); return ResolvableType.forType(paramType).resolve(method.getReturnType());
} }
} }

View File

@ -127,10 +127,13 @@ public class InvocableHandlerMethod extends HandlerMethod {
public Mono<Object> invoke(Message<?> message, Object... providedArgs) { public Mono<Object> invoke(Message<?> message, Object... providedArgs) {
return getMethodArgumentValues(message, providedArgs).flatMap(args -> { return getMethodArgumentValues(message, providedArgs).flatMap(args -> {
Object value; Object value;
boolean isSuspendingFunction = false;
try { try {
Method method = getBridgedMethod(); Method method = getBridgedMethod();
ReflectionUtils.makeAccessible(method); ReflectionUtils.makeAccessible(method);
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) { if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())
&& CoroutinesUtils.isSuspendingFunction(method)) {
isSuspendingFunction = true;
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args); value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
} }
else { else {
@ -151,7 +154,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
} }
MethodParameter returnType = getReturnType(); MethodParameter returnType = getReturnType();
ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(returnType.getParameterType()); Class<?> reactiveType = (isSuspendingFunction ? value.getClass() : returnType.getParameterType());
ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(reactiveType);
return (isAsyncVoidReturnType(returnType, adapter) ? return (isAsyncVoidReturnType(returnType, adapter) ?
Mono.from(adapter.toPublisher(value)) : Mono.justOrEmpty(value)); Mono.from(adapter.toPublisher(value)) : Mono.justOrEmpty(value));
}); });

View File

@ -39,6 +39,7 @@ import org.springframework.messaging.handler.annotation.MessageMapping
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler
import org.springframework.stereotype.Controller import org.springframework.stereotype.Controller
import reactor.core.publisher.Flux import reactor.core.publisher.Flux
import reactor.core.publisher.ReplayProcessor
import reactor.test.StepVerifier import reactor.test.StepVerifier
import java.time.Duration import java.time.Duration
@ -50,6 +51,34 @@ import java.time.Duration
*/ */
class RSocketClientToServerCoroutinesIntegrationTests { class RSocketClientToServerCoroutinesIntegrationTests {
@Test
fun fireAndForget() {
Flux.range(1, 3)
.concatMap { requester.route("receive").data("Hello $it").send() }
.blockLast()
StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads)
.expectNext("Hello 1")
.expectNext("Hello 2")
.expectNext("Hello 3")
.thenAwait(Duration.ofMillis(50))
.thenCancel()
.verify(Duration.ofSeconds(5))
}
@Test
fun fireAndForgetAsync() {
Flux.range(1, 3)
.concatMap { i: Int -> requester.route("receive-async").data("Hello $i").send() }
.blockLast()
StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads)
.expectNext("Hello 1")
.expectNext("Hello 2")
.expectNext("Hello 3")
.thenAwait(Duration.ofMillis(50))
.thenCancel()
.verify(Duration.ofSeconds(5))
}
@Test @Test
fun echoAsync() { fun echoAsync() {
val result = Flux.range(1, 3).concatMap { i -> requester.route("echo-async").data("Hello " + i!!).retrieveMono(String::class.java) } val result = Flux.range(1, 3).concatMap { i -> requester.route("echo-async").data("Hello " + i!!).retrieveMono(String::class.java) }
@ -70,6 +99,16 @@ class RSocketClientToServerCoroutinesIntegrationTests {
.verify(Duration.ofSeconds(5)) .verify(Duration.ofSeconds(5))
} }
@Test
fun echoStreamAsync() {
val result = requester.route("echo-stream-async").data("Hello").retrieveFlux(String::class.java)
StepVerifier.create(result)
.expectNext("Hello 0").expectNextCount(6).expectNext("Hello 7")
.thenCancel()
.verify(Duration.ofSeconds(5))
}
@Test @Test
fun echoChannel() { fun echoChannel() {
val result = requester.route("echo-channel") val result = requester.route("echo-channel")
@ -106,6 +145,19 @@ class RSocketClientToServerCoroutinesIntegrationTests {
@Controller @Controller
class ServerController { class ServerController {
val fireForgetPayloads = ReplayProcessor.create<String>()
@MessageMapping("receive")
fun receive(payload: String) {
fireForgetPayloads.onNext(payload)
}
@MessageMapping("receive-async")
suspend fun receiveAsync(payload: String) {
delay(10)
fireForgetPayloads.onNext(payload)
}
@MessageMapping("echo-async") @MessageMapping("echo-async")
suspend fun echoAsync(payload: String): String { suspend fun echoAsync(payload: String): String {
delay(10) delay(10)
@ -123,6 +175,18 @@ class RSocketClientToServerCoroutinesIntegrationTests {
} }
} }
@MessageMapping("echo-stream-async")
suspend fun echoStreamAsync(payload: String): Flow<String> {
delay(10)
var i = 0
return flow {
while(true) {
delay(10)
emit("$payload ${i++}")
}
}
}
@MessageMapping("echo-channel") @MessageMapping("echo-channel")
fun echoChannel(payloads: Flow<String>) = payloads.map { fun echoChannel(payloads: Flow<String>) = payloads.map {
delay(10) delay(10)
@ -185,8 +249,6 @@ class RSocketClientToServerCoroutinesIntegrationTests {
private lateinit var server: CloseableChannel private lateinit var server: CloseableChannel
private val interceptor = FireAndForgetCountingInterceptor()
private lateinit var requester: RSocketRequester private lateinit var requester: RSocketRequester
@ -196,7 +258,6 @@ class RSocketClientToServerCoroutinesIntegrationTests {
context = AnnotationConfigApplicationContext(ServerConfig::class.java) context = AnnotationConfigApplicationContext(ServerConfig::class.java)
server = RSocketFactory.receive() server = RSocketFactory.receive()
.addResponderPlugin(interceptor)
.frameDecoder(PayloadDecoder.ZERO_COPY) .frameDecoder(PayloadDecoder.ZERO_COPY)
.acceptor(context.getBean(RSocketMessageHandler::class.java).responder()) .acceptor(context.getBean(RSocketMessageHandler::class.java).responder())
.transport(TcpServerTransport.create("localhost", 7000)) .transport(TcpServerTransport.create("localhost", 7000))

View File

@ -139,7 +139,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
try { try {
ReflectionUtils.makeAccessible(getBridgedMethod()); ReflectionUtils.makeAccessible(getBridgedMethod());
Method method = getBridgedMethod(); Method method = getBridgedMethod();
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) { if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())
&& CoroutinesUtils.isSuspendingFunction(method)) {
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args); value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
} }
else { else {