Add coroutines support to RSocket @MessageMapping

Closes gh-22780
This commit is contained in:
Sebastien Deleuze 2019-04-30 15:42:55 +02:00
parent 842e7e5ef7
commit 2e6059f6b0
7 changed files with 311 additions and 3 deletions

View File

@ -12,6 +12,7 @@ def rsocketVersion = "0.12.2-RC3-SNAPSHOT"
dependencies {
compile(project(":spring-beans"))
compile(project(":spring-core"))
compileOnly(project(":spring-core-coroutines"))
optional(project(":spring-context"))
optional(project(":spring-oxm"))
optional("io.projectreactor.netty:reactor-netty")
@ -35,6 +36,7 @@ dependencies {
testCompile("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}")
testCompile("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}")
testCompile("org.xmlunit:xmlunit-matchers:2.6.2")
testCompile(project(":spring-core-coroutines"))
testRuntime("com.sun.xml.bind:jaxb-core:2.3.0.1")
testRuntime("com.sun.xml.bind:jaxb-impl:2.3.0.1")
testRuntime("com.sun.activation:javax.activation:1.2.0")

View File

@ -0,0 +1,42 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.handler.annotation.support.reactive;
import reactor.core.publisher.Mono;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver;
/**
* No-op resolver for method arguments of type {@link kotlin.coroutines.Continuation}.
*
* @author Sebastien Deleuze
* @since 5.2
*/
public class ContinuationHandlerMethodArgumentResolver implements HandlerMethodArgumentResolver {
@Override
public boolean supportsParameter(MethodParameter parameter) {
return "kotlin.coroutines.Continuation".equals(parameter.getParameterType().getName());
}
@Override
public Mono<Object> resolveArgument(MethodParameter parameter, Message<?> message) {
return Mono.empty();
}
}

View File

@ -34,6 +34,7 @@ import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.EmbeddedValueResolverAware;
import org.springframework.core.KotlinDetector;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.codec.Decoder;
import org.springframework.core.convert.ConversionService;
@ -238,6 +239,11 @@ public class MessageMappingMessageHandler extends AbstractMethodMessageHandler<C
resolvers.add(new HeadersMethodArgumentResolver());
resolvers.add(new DestinationVariableMethodArgumentResolver(this.conversionService));
// Type-based...
if (KotlinDetector.isKotlinPresent()) {
resolvers.add(new ContinuationHandlerMethodArgumentResolver());
}
// Custom resolvers
resolvers.addAll(getArgumentResolverConfigurer().getCustomResolvers());

View File

@ -16,16 +16,20 @@
package org.springframework.messaging.handler.invocation.reactive;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import kotlin.reflect.KFunction;
import kotlin.reflect.jvm.ReflectJvmMapping;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry;
@ -60,6 +64,8 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler
private static final ResolvableType OBJECT_RESOLVABLE_TYPE = ResolvableType.forClass(Object.class);
private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow";
protected final Log logger = LogFactory.getLog(getClass());
@ -132,7 +138,11 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler
ResolvableType elementType;
if (adapter != null) {
publisher = adapter.toPublisher(content);
ResolvableType genericType = returnValueType.getGeneric();
boolean isUnwrapped = KotlinDetector.isKotlinReflectPresent() &&
KotlinDetector.isKotlinType(returnType.getContainingClass()) &&
KotlinDelegate.isSuspend(returnType.getMethod()) &&
!COROUTINES_FLOW_CLASS_NAME.equals(returnValueType.toClass().getName());
ResolvableType genericType = isUnwrapped ? returnValueType : returnValueType.getGeneric();
elementType = getElementType(adapter, genericType);
}
else {
@ -213,4 +223,16 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler
*/
protected abstract Mono<Void> handleNoContent(MethodParameter returnType, Message<?> message);
/**
* Inner class to avoid a hard dependency on Kotlin at runtime.
*/
private static class KotlinDelegate {
static private boolean isSuspend(Method method) {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
return function != null && function.isSuspend();
}
}
}

View File

@ -26,7 +26,9 @@ import java.util.stream.Stream;
import reactor.core.publisher.Mono;
import org.springframework.core.CoroutinesUtils;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.core.ReactiveAdapter;
@ -125,13 +127,20 @@ public class InvocableHandlerMethod extends HandlerMethod {
* @param providedArgs optional list of argument values to match by type
* @return a Mono with the result from the invocation.
*/
@SuppressWarnings("KotlinInternalInJava")
public Mono<Object> invoke(Message<?> message, Object... providedArgs) {
return getMethodArgumentValues(message, providedArgs).flatMap(args -> {
Object value;
try {
ReflectionUtils.makeAccessible(getBridgedMethod());
value = getBridgedMethod().invoke(getBean(), args);
Method method = getBridgedMethod();
ReflectionUtils.makeAccessible(method);
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
value = CoroutinesUtils.invokeHandlerMethod(method, getBean(), args);
}
else {
value = method.invoke(getBean(), args);
}
}
catch (IllegalArgumentException ex) {
assertTargetBean(getBridgedMethod(), getBean(), args);

View File

@ -0,0 +1,225 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket
import java.time.Duration
import io.netty.buffer.PooledByteBufAllocator
import io.rsocket.RSocketFactory
import io.rsocket.frame.decoder.PayloadDecoder
import io.rsocket.transport.netty.server.CloseableChannel
import io.rsocket.transport.netty.server.TcpServerTransport
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import org.junit.AfterClass
import org.junit.BeforeClass
import org.junit.Test
import reactor.core.publisher.Flux
import reactor.core.publisher.ReplayProcessor
import reactor.test.StepVerifier
import org.springframework.context.annotation.AnnotationConfigApplicationContext
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.core.codec.CharSequenceEncoder
import org.springframework.core.codec.StringDecoder
import org.springframework.core.io.buffer.NettyDataBufferFactory
import org.springframework.messaging.handler.annotation.MessageExceptionHandler
import org.springframework.messaging.handler.annotation.MessageMapping
import org.springframework.stereotype.Controller
/**
* Coroutines server-side handling of RSocket requests.
*
* @author Sebastien Deleuze
* @author Rossen Stoyanchev
*/
class RSocketClientToServerCoroutinesIntegrationTests {
@Test
fun echoAsync() {
val result = Flux.range(1, 3).concatMap { i -> requester.route("echo-async").data("Hello " + i!!).retrieveMono(String::class.java) }
StepVerifier.create(result)
.expectNext("Hello 1 async").expectNext("Hello 2 async").expectNext("Hello 3 async")
.expectComplete()
.verify(Duration.ofSeconds(5))
}
@Test
fun echoStream() {
val result = requester.route("echo-stream").data("Hello").retrieveFlux(String::class.java)
StepVerifier.create(result)
.expectNext("Hello 0").expectNextCount(6).expectNext("Hello 7")
.thenCancel()
.verify(Duration.ofSeconds(5))
}
@Test
fun echoChannel() {
val result = requester.route("echo-channel")
.data(Flux.range(1, 10).map { i -> "Hello " + i!! }, String::class.java)
.retrieveFlux(String::class.java)
StepVerifier.create(result)
.expectNext("Hello 1 async").expectNextCount(8).expectNext("Hello 10 async")
.thenCancel() // https://github.com/rsocket/rsocket-java/issues/613
.verify(Duration.ofSeconds(5))
}
@Test
fun unitReturnValue() {
val result = requester.route("unit-return-value").data("Hello").retrieveFlux(String::class.java)
StepVerifier.create(result).expectComplete().verify(Duration.ofSeconds(5))
}
@Test
fun unitReturnValueFromExceptionHandler() {
val result = requester.route("unit-return-value").data("bad").retrieveFlux(String::class.java)
StepVerifier.create(result).expectComplete().verify(Duration.ofSeconds(5))
}
@Test
fun handleWithThrownException() {
val result = requester.route("thrown-exception").data("a").retrieveMono(String::class.java)
StepVerifier.create(result)
.expectNext("Invalid input error handled")
.expectComplete()
.verify(Duration.ofSeconds(5))
}
@FlowPreview
@Controller
class ServerController {
val fireForgetPayloads = ReplayProcessor.create<String>()
@MessageMapping("echo-async")
suspend fun echoAsync(payload: String): String {
delay(10)
return "$payload async"
}
@MessageMapping("echo-stream")
fun echoStream(payload: String): Flow<String> {
var i = 0
return flow {
while(true) {
delay(10)
emit("$payload ${i++}")
}
}
}
@MessageMapping("echo-channel")
fun echoChannel(payloads: Flow<String>) = payloads.map {
delay(10)
"$it async"
}
@MessageMapping("thrown-exception")
suspend fun handleAndThrow(payload: String): String {
delay(10)
throw IllegalArgumentException("Invalid input error")
}
@MessageMapping("unit-return-value")
suspend fun unitReturnValue(payload: String) =
if (payload != "bad") delay(10) else throw IllegalStateException("bad")
@MessageExceptionHandler
suspend fun handleException(ex: IllegalArgumentException): String {
delay(10)
return "${ex.message} handled"
}
@MessageExceptionHandler
suspend fun handleExceptionWithVoidReturnValue(ex: IllegalStateException) {
delay(10)
}
}
@Configuration
open class ServerConfig {
@Bean
open fun controller(): ServerController {
return ServerController()
}
@Bean
open fun messageHandlerAcceptor(): MessageHandlerAcceptor {
val acceptor = MessageHandlerAcceptor()
acceptor.rSocketStrategies = rsocketStrategies()
return acceptor
}
@Bean
open fun rsocketStrategies(): RSocketStrategies {
return RSocketStrategies.builder()
.decoder(StringDecoder.allMimeTypes())
.encoder(CharSequenceEncoder.allMimeTypes())
.dataBufferFactory(NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT))
.build()
}
}
companion object {
private lateinit var context: AnnotationConfigApplicationContext
private lateinit var server: CloseableChannel
private val interceptor = FireAndForgetCountingInterceptor()
private lateinit var requester: RSocketRequester
@BeforeClass
@JvmStatic
fun setupOnce() {
context = AnnotationConfigApplicationContext(ServerConfig::class.java)
server = RSocketFactory.receive()
.addServerPlugin(interceptor)
.frameDecoder(PayloadDecoder.ZERO_COPY)
.acceptor(context.getBean(MessageHandlerAcceptor::class.java))
.transport(TcpServerTransport.create("localhost", 7000))
.start()
.block()!!
requester = RSocketRequester.builder()
.rsocketFactory { factory -> factory.frameDecoder(PayloadDecoder.ZERO_COPY) }
.rsocketStrategies(context.getBean(RSocketStrategies::class.java))
.connectTcp("localhost", 7000)
.block()!!
}
@AfterClass
@JvmStatic
fun tearDownOnce() {
requester.rsocket().dispose()
server.dispose()
}
}
}

View File

@ -414,6 +414,8 @@ Spring Framework provides support for Coroutines on the following scope:
* Suspending function support in Spring WebFlux annotated `@Controller`
* Extensions for WebFlux {doc-root}/spring-framework/docs/{spring-version}/kdoc-api/spring-framework/org.springframework.web.reactive.function.client/index.html[client] and {doc-root}/spring-framework/docs/{spring-version}/kdoc-api/spring-framework/org.springframework.web.reactive.function.server/index.html[server] functional API.
* WebFlux.fn {doc-root}/spring-framework/docs/{spring-version}/kdoc-api/spring-framework/org.springframework.web.reactive.function.server/co-router.html[coRouter { }] DSL
* Suspending function and `Flow` support in RSocket `@MessageMapping` annotated methods
* Extensions for {doc-root}/spring-framework/docs/{spring-version}/kdoc-api/spring-framework/org.springframework.messaging.rsocket/index.html[`RSocketRequester`]
=== How Reactive translates to Coroutines?