Refine Coroutines annotated controller support
This commit refines Coroutines annotated controller support by considering Kotlin Unit as Java void and using the right ReactiveAdapter to support all use cases, including suspending functions that return Flow (usual when using APIs like WebClient). It also fixes RSocket fire and forget handling and adds related tests for that use case. Closes gh-24057 Closes gh-23866
This commit is contained in:
parent
21b2fc1f01
commit
712eac2915
|
@ -26,6 +26,7 @@ import kotlinx.coroutines.reactive.awaitFirstOrNull
|
|||
import kotlinx.coroutines.reactor.asFlux
|
||||
|
||||
import kotlinx.coroutines.reactor.mono
|
||||
import org.reactivestreams.Publisher
|
||||
import reactor.core.publisher.Mono
|
||||
import java.lang.reflect.InvocationTargetException
|
||||
import java.lang.reflect.Method
|
||||
|
@ -51,28 +52,29 @@ internal fun <T: Any> monoToDeferred(source: Mono<T>) =
|
|||
GlobalScope.async(Dispatchers.Unconfined) { source.awaitFirstOrNull() }
|
||||
|
||||
/**
|
||||
* Invoke a suspending function converting it to [Mono] or [reactor.core.publisher.Flux]
|
||||
* if necessary.
|
||||
* Return {@code true} if the method is a suspending function.
|
||||
*
|
||||
* @author Sebastien Deleuze
|
||||
* @since 5.2.2
|
||||
*/
|
||||
internal fun isSuspendingFunction(method: Method) = method.kotlinFunction!!.isSuspend
|
||||
|
||||
/**
|
||||
* Invoke a suspending function and converts it to [Mono] or [reactor.core.publisher.Flux].
|
||||
*
|
||||
* @author Sebastien Deleuze
|
||||
* @since 5.2
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Any? {
|
||||
internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Publisher<*> {
|
||||
val function = method.kotlinFunction!!
|
||||
return if (function.isSuspend) {
|
||||
val mono = mono(Dispatchers.Unconfined) {
|
||||
function.callSuspend(bean, *args.sliceArray(0..(args.size-2)))
|
||||
.let { if (it == Unit) null else it }
|
||||
}.onErrorMap(InvocationTargetException::class.java) { it.targetException }
|
||||
if (function.returnType.classifier == Flow::class) {
|
||||
mono.flatMapMany { (it as Flow<Any>).asFlux() }
|
||||
}
|
||||
else {
|
||||
mono
|
||||
}
|
||||
val mono = mono(Dispatchers.Unconfined) {
|
||||
function.callSuspend(bean, *args.sliceArray(0..(args.size-2))).let { if (it == Unit) null else it }
|
||||
}.onErrorMap(InvocationTargetException::class.java) { it.targetException }
|
||||
return if (function.returnType.classifier == Flow::class) {
|
||||
mono.flatMapMany { (it as Flow<Any>).asFlux() }
|
||||
}
|
||||
else {
|
||||
function.call(bean, *args)
|
||||
mono
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ import java.util.Map;
|
|||
import java.util.Optional;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
import kotlin.Unit;
|
||||
import kotlin.reflect.KFunction;
|
||||
import kotlin.reflect.KParameter;
|
||||
import kotlin.reflect.jvm.ReflectJvmMapping;
|
||||
|
@ -929,6 +930,9 @@ public class MethodParameter {
|
|||
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
|
||||
if (function != null && function.isSuspend()) {
|
||||
Type paramType = ReflectJvmMapping.getJavaType(function.getReturnType());
|
||||
if (paramType == Unit.class) {
|
||||
paramType = void.class;
|
||||
}
|
||||
return ResolvableType.forType(paramType).resolve(method.getReturnType());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -127,10 +127,13 @@ public class InvocableHandlerMethod extends HandlerMethod {
|
|||
public Mono<Object> invoke(Message<?> message, Object... providedArgs) {
|
||||
return getMethodArgumentValues(message, providedArgs).flatMap(args -> {
|
||||
Object value;
|
||||
boolean isSuspendingFunction = false;
|
||||
try {
|
||||
Method method = getBridgedMethod();
|
||||
ReflectionUtils.makeAccessible(method);
|
||||
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
|
||||
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())
|
||||
&& CoroutinesUtils.isSuspendingFunction(method)) {
|
||||
isSuspendingFunction = true;
|
||||
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
|
||||
}
|
||||
else {
|
||||
|
@ -151,7 +154,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
|
|||
}
|
||||
|
||||
MethodParameter returnType = getReturnType();
|
||||
ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(returnType.getParameterType());
|
||||
Class<?> reactiveType = (isSuspendingFunction ? value.getClass() : returnType.getParameterType());
|
||||
ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(reactiveType);
|
||||
return (isAsyncVoidReturnType(returnType, adapter) ?
|
||||
Mono.from(adapter.toPublisher(value)) : Mono.justOrEmpty(value));
|
||||
});
|
||||
|
|
|
@ -39,6 +39,7 @@ import org.springframework.messaging.handler.annotation.MessageMapping
|
|||
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler
|
||||
import org.springframework.stereotype.Controller
|
||||
import reactor.core.publisher.Flux
|
||||
import reactor.core.publisher.ReplayProcessor
|
||||
import reactor.test.StepVerifier
|
||||
import java.time.Duration
|
||||
|
||||
|
@ -50,6 +51,34 @@ import java.time.Duration
|
|||
*/
|
||||
class RSocketClientToServerCoroutinesIntegrationTests {
|
||||
|
||||
@Test
|
||||
fun fireAndForget() {
|
||||
Flux.range(1, 3)
|
||||
.concatMap { requester.route("receive").data("Hello $it").send() }
|
||||
.blockLast()
|
||||
StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads)
|
||||
.expectNext("Hello 1")
|
||||
.expectNext("Hello 2")
|
||||
.expectNext("Hello 3")
|
||||
.thenAwait(Duration.ofMillis(50))
|
||||
.thenCancel()
|
||||
.verify(Duration.ofSeconds(5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun fireAndForgetAsync() {
|
||||
Flux.range(1, 3)
|
||||
.concatMap { i: Int -> requester.route("receive-async").data("Hello $i").send() }
|
||||
.blockLast()
|
||||
StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads)
|
||||
.expectNext("Hello 1")
|
||||
.expectNext("Hello 2")
|
||||
.expectNext("Hello 3")
|
||||
.thenAwait(Duration.ofMillis(50))
|
||||
.thenCancel()
|
||||
.verify(Duration.ofSeconds(5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun echoAsync() {
|
||||
val result = Flux.range(1, 3).concatMap { i -> requester.route("echo-async").data("Hello " + i!!).retrieveMono(String::class.java) }
|
||||
|
@ -70,6 +99,16 @@ class RSocketClientToServerCoroutinesIntegrationTests {
|
|||
.verify(Duration.ofSeconds(5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun echoStreamAsync() {
|
||||
val result = requester.route("echo-stream-async").data("Hello").retrieveFlux(String::class.java)
|
||||
|
||||
StepVerifier.create(result)
|
||||
.expectNext("Hello 0").expectNextCount(6).expectNext("Hello 7")
|
||||
.thenCancel()
|
||||
.verify(Duration.ofSeconds(5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun echoChannel() {
|
||||
val result = requester.route("echo-channel")
|
||||
|
@ -106,6 +145,19 @@ class RSocketClientToServerCoroutinesIntegrationTests {
|
|||
@Controller
|
||||
class ServerController {
|
||||
|
||||
val fireForgetPayloads = ReplayProcessor.create<String>()
|
||||
|
||||
@MessageMapping("receive")
|
||||
fun receive(payload: String) {
|
||||
fireForgetPayloads.onNext(payload)
|
||||
}
|
||||
|
||||
@MessageMapping("receive-async")
|
||||
suspend fun receiveAsync(payload: String) {
|
||||
delay(10)
|
||||
fireForgetPayloads.onNext(payload)
|
||||
}
|
||||
|
||||
@MessageMapping("echo-async")
|
||||
suspend fun echoAsync(payload: String): String {
|
||||
delay(10)
|
||||
|
@ -123,6 +175,18 @@ class RSocketClientToServerCoroutinesIntegrationTests {
|
|||
}
|
||||
}
|
||||
|
||||
@MessageMapping("echo-stream-async")
|
||||
suspend fun echoStreamAsync(payload: String): Flow<String> {
|
||||
delay(10)
|
||||
var i = 0
|
||||
return flow {
|
||||
while(true) {
|
||||
delay(10)
|
||||
emit("$payload ${i++}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@MessageMapping("echo-channel")
|
||||
fun echoChannel(payloads: Flow<String>) = payloads.map {
|
||||
delay(10)
|
||||
|
@ -185,8 +249,6 @@ class RSocketClientToServerCoroutinesIntegrationTests {
|
|||
|
||||
private lateinit var server: CloseableChannel
|
||||
|
||||
private val interceptor = FireAndForgetCountingInterceptor()
|
||||
|
||||
private lateinit var requester: RSocketRequester
|
||||
|
||||
|
||||
|
@ -196,7 +258,6 @@ class RSocketClientToServerCoroutinesIntegrationTests {
|
|||
context = AnnotationConfigApplicationContext(ServerConfig::class.java)
|
||||
|
||||
server = RSocketFactory.receive()
|
||||
.addResponderPlugin(interceptor)
|
||||
.frameDecoder(PayloadDecoder.ZERO_COPY)
|
||||
.acceptor(context.getBean(RSocketMessageHandler::class.java).responder())
|
||||
.transport(TcpServerTransport.create("localhost", 7000))
|
||||
|
|
|
@ -139,7 +139,8 @@ public class InvocableHandlerMethod extends HandlerMethod {
|
|||
try {
|
||||
ReflectionUtils.makeAccessible(getBridgedMethod());
|
||||
Method method = getBridgedMethod();
|
||||
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
|
||||
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())
|
||||
&& CoroutinesUtils.isSuspendingFunction(method)) {
|
||||
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
|
||||
}
|
||||
else {
|
||||
|
|
Loading…
Reference in New Issue