Add RSocketServiceMethod support for suspending functions
See #34868 Signed-off-by: Dmitry Sulman <dmitry.sulman@gmail.com>
This commit is contained in:
parent
2faed3cdbb
commit
255ef569d7
|
@ -17,6 +17,7 @@
|
|||
package org.springframework.aop.framework;
|
||||
|
||||
import kotlin.coroutines.Continuation;
|
||||
import kotlinx.coroutines.flow.Flow;
|
||||
import kotlinx.coroutines.reactive.ReactiveFlowKt;
|
||||
import kotlinx.coroutines.reactor.MonoKt;
|
||||
import org.jspecify.annotations.Nullable;
|
||||
|
@ -35,6 +36,9 @@ abstract class CoroutinesUtils {
|
|||
if (publisher instanceof Publisher<?> rsPublisher) {
|
||||
return ReactiveFlowKt.asFlow(rsPublisher);
|
||||
}
|
||||
else if (publisher instanceof Flow<?>) {
|
||||
return publisher;
|
||||
}
|
||||
else {
|
||||
throw new IllegalArgumentException("Not a Reactive Streams Publisher: " + publisher);
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.springframework.aop.framework
|
|||
|
||||
import kotlinx.coroutines.CoroutineName
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.flowOf
|
||||
import kotlinx.coroutines.flow.toList
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
|
@ -72,4 +73,16 @@ class CoroutinesUtilsTests {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun flowAsFlow() {
|
||||
val value1 = "foo"
|
||||
val value2 = "bar"
|
||||
val values = flowOf(value1, value2)
|
||||
val flow = CoroutinesUtils.asFlow(values) as Flow<String>
|
||||
runBlocking {
|
||||
assertThat(flow.toList()).containsExactly(value1, value2)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.reactivestreams.Publisher;
|
|||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.core.DefaultParameterNameDiscoverer;
|
||||
import org.springframework.core.KotlinDetector;
|
||||
import org.springframework.core.MethodParameter;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.ReactiveAdapter;
|
||||
|
@ -54,6 +55,8 @@ import org.springframework.util.StringValueResolver;
|
|||
*/
|
||||
final class RSocketServiceMethod {
|
||||
|
||||
private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow";
|
||||
|
||||
private final Method method;
|
||||
|
||||
private final MethodParameter[] parameters;
|
||||
|
@ -82,6 +85,10 @@ final class RSocketServiceMethod {
|
|||
if (count == 0) {
|
||||
return new MethodParameter[0];
|
||||
}
|
||||
if (KotlinDetector.isSuspendingFunction(method)) {
|
||||
count -= 1;
|
||||
}
|
||||
|
||||
DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
|
||||
MethodParameter[] parameters = new MethodParameter[count];
|
||||
for (int i = 0; i < count; i++) {
|
||||
|
@ -129,10 +136,16 @@ final class RSocketServiceMethod {
|
|||
|
||||
MethodParameter returnParam = new MethodParameter(method, -1);
|
||||
Class<?> returnType = returnParam.getParameterType();
|
||||
boolean isUnwrapped = KotlinDetector.isSuspendingFunction(method) &&
|
||||
!COROUTINES_FLOW_CLASS_NAME.equals(returnParam.getParameterType().getName());
|
||||
if (isUnwrapped) {
|
||||
returnType = Mono.class;
|
||||
}
|
||||
|
||||
ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);
|
||||
|
||||
MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());
|
||||
Class<?> actualType = actualParam.getNestedParameterType();
|
||||
Class<?> actualType = isUnwrapped ? actualParam.getParameterType() : actualParam.getNestedParameterType();
|
||||
|
||||
Function<RSocketRequestValues, Publisher<?>> responseFunction;
|
||||
if (ClassUtils.isVoidType(actualType) || (reactiveAdapter != null && reactiveAdapter.isNoValue())) {
|
||||
|
@ -147,7 +160,8 @@ final class RSocketServiceMethod {
|
|||
}
|
||||
else {
|
||||
ParameterizedTypeReference<?> payloadType =
|
||||
ParameterizedTypeReference.forType(actualParam.getNestedGenericParameterType());
|
||||
ParameterizedTypeReference.forType(isUnwrapped ? actualParam.getGenericParameterType() :
|
||||
actualParam.getNestedGenericParameterType());
|
||||
|
||||
responseFunction = values -> (
|
||||
reactiveAdapter.isMultiValue() ?
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.jspecify.annotations.Nullable;
|
|||
|
||||
import org.springframework.aop.framework.ProxyFactory;
|
||||
import org.springframework.aop.framework.ReflectiveMethodInvocation;
|
||||
import org.springframework.core.KotlinDetector;
|
||||
import org.springframework.core.MethodIntrospector;
|
||||
import org.springframework.core.ReactiveAdapterRegistry;
|
||||
import org.springframework.core.annotation.AnnotatedElementUtils;
|
||||
|
@ -246,7 +247,9 @@ public final class RSocketServiceProxyFactory {
|
|||
Method method = invocation.getMethod();
|
||||
RSocketServiceMethod serviceMethod = this.serviceMethods.get(method);
|
||||
if (serviceMethod != null) {
|
||||
return serviceMethod.invoke(invocation.getArguments());
|
||||
@Nullable Object[] arguments = KotlinDetector.isSuspendingFunction(method) ?
|
||||
resolveCoroutinesArguments(invocation.getArguments()) : invocation.getArguments();
|
||||
return serviceMethod.invoke(arguments);
|
||||
}
|
||||
if (method.isDefault()) {
|
||||
if (invocation instanceof ReflectiveMethodInvocation reflectiveMethodInvocation) {
|
||||
|
@ -256,6 +259,12 @@ public final class RSocketServiceProxyFactory {
|
|||
}
|
||||
throw new IllegalStateException("Unexpected method invocation: " + method);
|
||||
}
|
||||
|
||||
private static Object[] resolveCoroutinesArguments(@Nullable Object[] args) {
|
||||
Object[] functionArgs = new Object[args.length - 1];
|
||||
System.arraycopy(args, 0, functionArgs, 0, args.length - 1);
|
||||
return functionArgs;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
/*
|
||||
* Copyright 2002-present the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.messaging.rsocket.service
|
||||
|
||||
import io.rsocket.util.DefaultPayload
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.flowOf
|
||||
import kotlinx.coroutines.flow.map
|
||||
import kotlinx.coroutines.flow.toList
|
||||
import kotlinx.coroutines.reactive.asFlow
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.jupiter.api.BeforeEach
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.springframework.messaging.rsocket.RSocketRequester
|
||||
import org.springframework.messaging.rsocket.RSocketStrategies
|
||||
import org.springframework.messaging.rsocket.TestRSocket
|
||||
import org.springframework.util.MimeTypeUtils.TEXT_PLAIN
|
||||
import reactor.core.publisher.Flux
|
||||
import reactor.core.publisher.Mono
|
||||
|
||||
/**
|
||||
* Kotlin tests for [RSocketServiceMethod].
|
||||
*
|
||||
* @author Dmitry Sulman
|
||||
*/
|
||||
class RSocketServiceMethodKotlinTests {
|
||||
|
||||
private lateinit var rsocket: TestRSocket
|
||||
|
||||
private lateinit var proxyFactory: RSocketServiceProxyFactory
|
||||
|
||||
@BeforeEach
|
||||
fun setUp() {
|
||||
rsocket = TestRSocket()
|
||||
val requester = RSocketRequester.wrap(rsocket, TEXT_PLAIN, TEXT_PLAIN, RSocketStrategies.create())
|
||||
proxyFactory = RSocketServiceProxyFactory.builder(requester).build()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun fireAndForget(): Unit = runBlocking {
|
||||
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
|
||||
|
||||
val requestPayload = "request"
|
||||
service.fireAndForget(requestPayload)
|
||||
|
||||
assertThat(rsocket.savedMethodName).isEqualTo("fireAndForget")
|
||||
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("ff")
|
||||
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun requestResponse(): Unit = runBlocking {
|
||||
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
|
||||
|
||||
val requestPayload = "request"
|
||||
val responsePayload = "response"
|
||||
rsocket.setPayloadMonoToReturn(Mono.just(DefaultPayload.create(responsePayload)))
|
||||
val response = service.requestResponse(requestPayload)
|
||||
|
||||
assertThat(response).isEqualTo(responsePayload)
|
||||
assertThat(rsocket.savedMethodName).isEqualTo("requestResponse")
|
||||
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rr")
|
||||
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun requestStream(): Unit = runBlocking {
|
||||
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
|
||||
|
||||
val requestPayload = "request"
|
||||
val responsePayload1 = "response1"
|
||||
val responsePayload2 = "response2"
|
||||
rsocket.setPayloadFluxToReturn(
|
||||
Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2)))
|
||||
val response = service.requestStream(requestPayload).toList()
|
||||
|
||||
assertThat(response).containsExactly(responsePayload1, responsePayload2)
|
||||
assertThat(rsocket.savedMethodName).isEqualTo("requestStream")
|
||||
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rs")
|
||||
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun requestChannel(): Unit = runBlocking {
|
||||
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
|
||||
|
||||
val requestPayload1 = "request1"
|
||||
val requestPayload2 = "request2"
|
||||
val responsePayload1 = "response1"
|
||||
val responsePayload2 = "response2"
|
||||
rsocket.setPayloadFluxToReturn(
|
||||
Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2)))
|
||||
val response = service.requestChannel(flowOf(requestPayload1, requestPayload2)).toList()
|
||||
|
||||
assertThat(response).containsExactly(responsePayload1, responsePayload2)
|
||||
assertThat(rsocket.savedMethodName).isEqualTo("requestChannel")
|
||||
|
||||
val savedPayloads = rsocket.savedPayloadFlux
|
||||
?.asFlow()
|
||||
?.map { it.dataUtf8 }
|
||||
?.toList()
|
||||
assertThat(savedPayloads).containsExactly(requestPayload1, requestPayload2)
|
||||
}
|
||||
|
||||
private interface SuspendingFunctionsService {
|
||||
|
||||
@RSocketExchange("ff")
|
||||
suspend fun fireAndForget(input: String)
|
||||
|
||||
@RSocketExchange("rr")
|
||||
suspend fun requestResponse(input: String): String
|
||||
|
||||
@RSocketExchange("rs")
|
||||
suspend fun requestStream(input: String): Flow<String>
|
||||
|
||||
@RSocketExchange("rc")
|
||||
suspend fun requestChannel(input: Flow<String>): Flow<String>
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue