Add RSocketServiceMethod support for suspending functions

See #34868

Signed-off-by: Dmitry Sulman <dmitry.sulman@gmail.com>
This commit is contained in:
Dmitry Sulman 2025-09-12 19:22:46 +03:00
parent 2faed3cdbb
commit 255ef569d7
5 changed files with 177 additions and 3 deletions

View File

@ -17,6 +17,7 @@
package org.springframework.aop.framework;
import kotlin.coroutines.Continuation;
import kotlinx.coroutines.flow.Flow;
import kotlinx.coroutines.reactive.ReactiveFlowKt;
import kotlinx.coroutines.reactor.MonoKt;
import org.jspecify.annotations.Nullable;
@ -35,6 +36,9 @@ abstract class CoroutinesUtils {
if (publisher instanceof Publisher<?> rsPublisher) {
return ReactiveFlowKt.asFlow(rsPublisher);
}
else if (publisher instanceof Flow<?>) {
return publisher;
}
else {
throw new IllegalArgumentException("Not a Reactive Streams Publisher: " + publisher);
}

View File

@ -18,6 +18,7 @@ package org.springframework.aop.framework
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
@ -72,4 +73,16 @@ class CoroutinesUtilsTests {
}
}
@Test
@Suppress("UNCHECKED_CAST")
fun flowAsFlow() {
val value1 = "foo"
val value2 = "bar"
val values = flowOf(value1, value2)
val flow = CoroutinesUtils.asFlow(values) as Flow<String>
runBlocking {
assertThat(flow.toList()).containsExactly(value1, value2)
}
}
}

View File

@ -30,6 +30,7 @@ import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ReactiveAdapter;
@ -54,6 +55,8 @@ import org.springframework.util.StringValueResolver;
*/
final class RSocketServiceMethod {
private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow";
private final Method method;
private final MethodParameter[] parameters;
@ -82,6 +85,10 @@ final class RSocketServiceMethod {
if (count == 0) {
return new MethodParameter[0];
}
if (KotlinDetector.isSuspendingFunction(method)) {
count -= 1;
}
DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
MethodParameter[] parameters = new MethodParameter[count];
for (int i = 0; i < count; i++) {
@ -129,10 +136,16 @@ final class RSocketServiceMethod {
MethodParameter returnParam = new MethodParameter(method, -1);
Class<?> returnType = returnParam.getParameterType();
boolean isUnwrapped = KotlinDetector.isSuspendingFunction(method) &&
!COROUTINES_FLOW_CLASS_NAME.equals(returnParam.getParameterType().getName());
if (isUnwrapped) {
returnType = Mono.class;
}
ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);
MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());
Class<?> actualType = actualParam.getNestedParameterType();
Class<?> actualType = isUnwrapped ? actualParam.getParameterType() : actualParam.getNestedParameterType();
Function<RSocketRequestValues, Publisher<?>> responseFunction;
if (ClassUtils.isVoidType(actualType) || (reactiveAdapter != null && reactiveAdapter.isNoValue())) {
@ -147,7 +160,8 @@ final class RSocketServiceMethod {
}
else {
ParameterizedTypeReference<?> payloadType =
ParameterizedTypeReference.forType(actualParam.getNestedGenericParameterType());
ParameterizedTypeReference.forType(isUnwrapped ? actualParam.getGenericParameterType() :
actualParam.getNestedGenericParameterType());
responseFunction = values -> (
reactiveAdapter.isMultiValue() ?

View File

@ -31,6 +31,7 @@ import org.jspecify.annotations.Nullable;
import org.springframework.aop.framework.ProxyFactory;
import org.springframework.aop.framework.ReflectiveMethodInvocation;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.annotation.AnnotatedElementUtils;
@ -246,7 +247,9 @@ public final class RSocketServiceProxyFactory {
Method method = invocation.getMethod();
RSocketServiceMethod serviceMethod = this.serviceMethods.get(method);
if (serviceMethod != null) {
return serviceMethod.invoke(invocation.getArguments());
@Nullable Object[] arguments = KotlinDetector.isSuspendingFunction(method) ?
resolveCoroutinesArguments(invocation.getArguments()) : invocation.getArguments();
return serviceMethod.invoke(arguments);
}
if (method.isDefault()) {
if (invocation instanceof ReflectiveMethodInvocation reflectiveMethodInvocation) {
@ -256,6 +259,12 @@ public final class RSocketServiceProxyFactory {
}
throw new IllegalStateException("Unexpected method invocation: " + method);
}
private static Object[] resolveCoroutinesArguments(@Nullable Object[] args) {
Object[] functionArgs = new Object[args.length - 1];
System.arraycopy(args, 0, functionArgs, 0, args.length - 1);
return functionArgs;
}
}
}

View File

@ -0,0 +1,134 @@
/*
* Copyright 2002-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service
import io.rsocket.util.DefaultPayload
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.reactive.asFlow
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.messaging.rsocket.RSocketRequester
import org.springframework.messaging.rsocket.RSocketStrategies
import org.springframework.messaging.rsocket.TestRSocket
import org.springframework.util.MimeTypeUtils.TEXT_PLAIN
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
/**
* Kotlin tests for [RSocketServiceMethod].
*
* @author Dmitry Sulman
*/
class RSocketServiceMethodKotlinTests {
private lateinit var rsocket: TestRSocket
private lateinit var proxyFactory: RSocketServiceProxyFactory
@BeforeEach
fun setUp() {
rsocket = TestRSocket()
val requester = RSocketRequester.wrap(rsocket, TEXT_PLAIN, TEXT_PLAIN, RSocketStrategies.create())
proxyFactory = RSocketServiceProxyFactory.builder(requester).build()
}
@Test
fun fireAndForget(): Unit = runBlocking {
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
val requestPayload = "request"
service.fireAndForget(requestPayload)
assertThat(rsocket.savedMethodName).isEqualTo("fireAndForget")
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("ff")
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
}
@Test
fun requestResponse(): Unit = runBlocking {
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
val requestPayload = "request"
val responsePayload = "response"
rsocket.setPayloadMonoToReturn(Mono.just(DefaultPayload.create(responsePayload)))
val response = service.requestResponse(requestPayload)
assertThat(response).isEqualTo(responsePayload)
assertThat(rsocket.savedMethodName).isEqualTo("requestResponse")
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rr")
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
}
@Test
fun requestStream(): Unit = runBlocking {
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
val requestPayload = "request"
val responsePayload1 = "response1"
val responsePayload2 = "response2"
rsocket.setPayloadFluxToReturn(
Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2)))
val response = service.requestStream(requestPayload).toList()
assertThat(response).containsExactly(responsePayload1, responsePayload2)
assertThat(rsocket.savedMethodName).isEqualTo("requestStream")
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rs")
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
}
@Test
fun requestChannel(): Unit = runBlocking {
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
val requestPayload1 = "request1"
val requestPayload2 = "request2"
val responsePayload1 = "response1"
val responsePayload2 = "response2"
rsocket.setPayloadFluxToReturn(
Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2)))
val response = service.requestChannel(flowOf(requestPayload1, requestPayload2)).toList()
assertThat(response).containsExactly(responsePayload1, responsePayload2)
assertThat(rsocket.savedMethodName).isEqualTo("requestChannel")
val savedPayloads = rsocket.savedPayloadFlux
?.asFlow()
?.map { it.dataUtf8 }
?.toList()
assertThat(savedPayloads).containsExactly(requestPayload1, requestPayload2)
}
private interface SuspendingFunctionsService {
@RSocketExchange("ff")
suspend fun fireAndForget(input: String)
@RSocketExchange("rr")
suspend fun requestResponse(input: String): String
@RSocketExchange("rs")
suspend fun requestStream(input: String): Flow<String>
@RSocketExchange("rc")
suspend fun requestChannel(input: Flow<String>): Flow<String>
}
}