RSocket @MessageMapping handling

See gh-21987
This commit is contained in:
Rossen Stoyanchev 2019-02-10 14:45:16 -05:00
parent f2bb95ba7b
commit 4e78b5df2f
10 changed files with 1153 additions and 0 deletions

View File

@ -7,12 +7,15 @@ dependencyManagement {
}
}
def rsocketVersion = "0.11.15"
dependencies {
compile(project(":spring-beans"))
compile(project(":spring-core"))
optional(project(":spring-context"))
optional(project(":spring-oxm"))
optional("io.projectreactor.netty:reactor-netty")
optional("io.rsocket:rsocket-core:${rsocketVersion}")
optional("com.fasterxml.jackson.core:jackson-databind:${jackson2Version}")
optional("javax.xml.bind:jaxb-api:2.3.1")
testCompile("javax.inject:javax.inject-tck:1")
@ -26,6 +29,7 @@ dependencies {
testCompile("org.apache.activemq:activemq-stomp:5.8.0")
testCompile("io.projectreactor:reactor-test")
testCompile "io.reactivex.rxjava2:rxjava:${rxjava2Version}"
testCompile("io.rsocket:rsocket-transport-netty:${rsocketVersion}")
testCompile("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}")
testCompile("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}")
testCompile("org.xmlunit:xmlunit-matchers:2.6.2")

View File

@ -0,0 +1,114 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.util.function.Function;
import java.util.function.Predicate;
import io.netty.buffer.PooledByteBufAllocator;
import io.rsocket.ConnectionSetupPayload;
import io.rsocket.RSocket;
import io.rsocket.SocketAcceptor;
import reactor.core.publisher.Mono;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.ReactiveMessageChannel;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
/**
* RSocket acceptor for
* {@link io.rsocket.RSocketFactory.ClientRSocketFactory#acceptor(Function) client} or
* {@link io.rsocket.RSocketFactory.ServerRSocketFactory#acceptor(SocketAcceptor) server}
* side use. It wraps requests with a {@link Message} envelope and sends them
* to a {@link ReactiveMessageChannel} for handling, e.g. via
* {@code @MessageMapping} method.
*
* @author Rossen Stoyanchev
* @since 5.2
*/
public final class MessagingAcceptor implements SocketAcceptor, Function<RSocket, RSocket> {
private final ReactiveMessageChannel messageChannel;
private NettyDataBufferFactory bufferFactory = new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT);
@Nullable
private MimeType defaultDataMimeType;
/**
* Constructor with a message channel to send messages to.
* @param messageChannel the message channel to use
* <p>This assumes a Spring configuration setup with a
* {@code ReactiveMessageChannel} and an {@link RSocketMessageHandler} which
* by default auto-detects {@code @MessageMapping} methods in
* {@code @Controller} classes, but can also be configured with a
* {@link RSocketMessageHandler#setHandlerPredicate(Predicate) handlerPredicate}
* or with handler instances.
*/
public MessagingAcceptor(ReactiveMessageChannel messageChannel) {
Assert.notNull(messageChannel, "ReactiveMessageChannel is required");
this.messageChannel = messageChannel;
}
/**
* Configure the default content type for data payloads. For server
* acceptors this is available from the {@link ConnectionSetupPayload} but
* for client acceptors it's not and must be provided here.
* <p>By default this is not set.
* @param defaultDataMimeType the MimeType to use
*/
public void setDefaultDataMimeType(@Nullable MimeType defaultDataMimeType) {
this.defaultDataMimeType = defaultDataMimeType;
}
/**
* Configure the buffer factory to use.
* <p>By default this is initialized with the allocator instance
* {@link PooledByteBufAllocator#DEFAULT}.
* @param bufferFactory the bufferFactory to use
*/
public void setNettyDataBufferFactory(NettyDataBufferFactory bufferFactory) {
Assert.notNull(bufferFactory, "DataBufferFactory is required");
this.bufferFactory = bufferFactory;
}
@Override
public Mono<RSocket> accept(ConnectionSetupPayload setupPayload, RSocket sendingRSocket) {
MimeType mimeType = setupPayload.dataMimeType() != null ?
MimeTypeUtils.parseMimeType(setupPayload.dataMimeType()) : this.defaultDataMimeType;
MessagingRSocket rsocket = createRSocket(sendingRSocket, mimeType);
return rsocket.afterConnectionEstablished(setupPayload).then(Mono.just(rsocket));
}
@Override
public RSocket apply(RSocket sendingRSocket) {
return createRSocket(sendingRSocket, this.defaultDataMimeType);
}
private MessagingRSocket createRSocket(RSocket sendingRSocket, @Nullable MimeType dataMimeType) {
return new MessagingRSocket(this.messageChannel, this.bufferFactory, sendingRSocket, dataMimeType);
}
}

View File

@ -0,0 +1,165 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.util.function.Function;
import io.rsocket.ConnectionSetupPayload;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBuffer;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.core.io.buffer.PooledDataBuffer;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.ReactiveMessageChannel;
import org.springframework.messaging.handler.DestinationPatternsMessageCondition;
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
/**
* Package private implementation of {@link RSocket} used from
* {@link MessagingAcceptor}.
*
* @author Rossen Stoyanchev
* @since 5.2
*/
class MessagingRSocket implements RSocket {
private final ReactiveMessageChannel messageChannel;
private final NettyDataBufferFactory bufferFactory;
private final RSocket sendingRSocket;
@Nullable
private final MimeType dataMimeType;
MessagingRSocket(ReactiveMessageChannel messageChannel, NettyDataBufferFactory bufferFactory,
RSocket sendingRSocket, @Nullable MimeType dataMimeType) {
Assert.notNull(messageChannel, "'messageChannel' is required");
Assert.notNull(bufferFactory, "'bufferFactory' is required");
Assert.notNull(sendingRSocket, "'sendingRSocket' is required");
this.messageChannel = messageChannel;
this.bufferFactory = bufferFactory;
this.sendingRSocket = sendingRSocket;
this.dataMimeType = dataMimeType;
}
public Mono<Void> afterConnectionEstablished(ConnectionSetupPayload payload) {
return execute(payload).flatMap(flux -> flux.take(0).then());
}
@Override
public Mono<Void> fireAndForget(Payload payload) {
return execute(payload).flatMap(flux -> flux.take(0).then());
}
@Override
public Mono<Payload> requestResponse(Payload payload) {
return execute(payload).flatMap(Flux::next);
}
@Override
public Flux<Payload> requestStream(Payload payload) {
return execute(payload).flatMapMany(Function.identity());
}
@Override
public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
return Flux.from(payloads)
.switchOnFirst((signal, inner) -> {
Payload first = signal.get();
return first != null ? execute(first, inner).flatMapMany(Function.identity()) : inner;
});
}
@Override
public Mono<Void> metadataPush(Payload payload) {
return null;
}
private Mono<Flux<Payload>> execute(Payload payload) {
return execute(payload, Flux.just(payload));
}
private Mono<Flux<Payload>> execute(Payload firstPayload, Flux<Payload> payloads) {
// TODO:
// Since we do retain(), we need to ensure buffers are released if not consumed,
// e.g. error before Flux subscribed to, no handler found, @MessageMapping ignores payload, etc.
Flux<NettyDataBuffer> payloadDataBuffers = payloads
.map(payload -> this.bufferFactory.wrap(payload.retain().sliceData()))
.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
MonoProcessor<Flux<Payload>> replyMono = MonoProcessor.create();
MessageHeaders headers = createHeaders(firstPayload, replyMono);
Message<?> message = MessageBuilder.createMessage(payloadDataBuffers, headers);
return this.messageChannel.send(message).flatMap(result -> result ?
replyMono.isTerminated() ? replyMono : Mono.empty() :
Mono.error(new MessageDeliveryException("RSocket interaction not handled")));
}
private MessageHeaders createHeaders(Payload payload, MonoProcessor<?> replyMono) {
// For now treat the metadata as a simple string with routing information.
// We'll have to get more sophisticated once the routing extension is completed.
// https://github.com/rsocket/rsocket-java/issues/568
MessageHeaderAccessor headers = new MessageHeaderAccessor();
String destination = payload.getMetadataUtf8();
headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination);
if (this.dataMimeType != null) {
headers.setContentType(this.dataMimeType);
}
headers.setHeader(SendingRSocketMethodArgumentResolver.SENDING_RSOCKET_HEADER, this.sendingRSocket);
headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono);
headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, this.bufferFactory);
return headers.getMessageHeaders();
}
@Override
public Mono<Void> onClose() {
return null;
}
@Override
public void dispose() {
}
}

View File

@ -0,0 +1,97 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.util.ArrayList;
import java.util.List;
import org.springframework.core.codec.Encoder;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.ReactiveSubscribableChannel;
import org.springframework.messaging.handler.annotation.support.reactive.MessageMappingMessageHandler;
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler;
import org.springframework.util.StringUtils;
/**
* RSocket-specific extension of {@link MessageMappingMessageHandler}.
*
* <p>The configured {@link #setEncoders(List) encoders} are used to encode the
* return values from handler methods, with the help of
* {@link RSocketPayloadReturnValueHandler}.
*
* @author Rossen Stoyanchev
* @since 5.2
*/
public class RSocketMessageHandler extends MessageMappingMessageHandler {
private final List<Encoder<?>> encoders = new ArrayList<>();
public RSocketMessageHandler(ReactiveSubscribableChannel inboundChannel) {
super(inboundChannel);
}
public RSocketMessageHandler(ReactiveSubscribableChannel inboundChannel, List<Object> handlers) {
super(inboundChannel);
setHandlerPredicate(null); // disable auto-detection..
for (Object handler : handlers) {
detectHandlerMethods(handler);
}
}
/**
* Configure the encoders to use for encoding handler method return values.
*/
public void setEncoders(List<? extends Encoder<?>> encoders) {
this.encoders.addAll(encoders);
}
/**
* Return the configured {@link #setEncoders(List) encoders}.
*/
public List<? extends Encoder<?>> getEncoders() {
return this.encoders;
}
@Override
public void afterPropertiesSet() {
getArgumentResolverConfigurer().addCustomResolver(new SendingRSocketMethodArgumentResolver());
super.afterPropertiesSet();
}
@Override
protected List<? extends HandlerMethodReturnValueHandler> initReturnValueHandlers() {
List<HandlerMethodReturnValueHandler> handlers = new ArrayList<>();
handlers.add(new RSocketPayloadReturnValueHandler(this.encoders, getReactiveAdapterRegistry()));
handlers.addAll(getReturnValueHandlerConfigurer().getCustomHandlers());
return handlers;
}
@Override
protected void handleNoMatch(@Nullable String destination, Message<?> message) {
// Ignore empty destination, probably the ConnectionSetupPayload
if (!StringUtils.isEmpty(destination)) {
super.handleNoMatch(destination, message);
throw new MessageDeliveryException("No handler for '" + destination + "'");
}
}
}

View File

@ -0,0 +1,88 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.util.List;
import io.rsocket.Payload;
import io.rsocket.util.ByteBufPayload;
import io.rsocket.util.DefaultPayload;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DefaultDataBuffer;
import org.springframework.core.io.buffer.NettyDataBuffer;
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.invocation.reactive.AbstractEncoderMethodReturnValueHandler;
import org.springframework.util.Assert;
/**
* Extension of {@link AbstractEncoderMethodReturnValueHandler} that
* {@link #handleEncodedContent handles} encoded content by wrapping data buffers
* as RSocket payloads and by passing those to the {@link MonoProcessor}
* from the {@link #RESPONSE_HEADER} header.
*
* @author Rossen Stoyanchev
* @since 5.2
*/
public class RSocketPayloadReturnValueHandler extends AbstractEncoderMethodReturnValueHandler {
/**
* Message header name that is expected to have a {@link MonoProcessor}
* which will receive the {@code Flux<Payload>} that represents the response.
*/
public static final String RESPONSE_HEADER = "rsocketResponse";
public RSocketPayloadReturnValueHandler(List<Encoder<?>> encoders, ReactiveAdapterRegistry registry) {
super(encoders, registry);
}
@Override
@SuppressWarnings("unchecked")
protected Mono<Void> handleEncodedContent(
Flux<DataBuffer> encodedContent, MethodParameter returnType, Message<?> message) {
Object headerValue = message.getHeaders().get(RESPONSE_HEADER);
Assert.notNull(headerValue, "Missing '" + RESPONSE_HEADER + "'");
Assert.isInstanceOf(MonoProcessor.class, headerValue, "Expected MonoProcessor");
MonoProcessor<Flux<Payload>> monoProcessor = (MonoProcessor<Flux<Payload>>) headerValue;
monoProcessor.onNext(encodedContent.map(this::toPayload));
monoProcessor.onComplete();
return Mono.empty();
}
private Payload toPayload(DataBuffer dataBuffer) {
if (dataBuffer instanceof NettyDataBuffer) {
return ByteBufPayload.create(((NettyDataBuffer) dataBuffer).getNativeBuffer());
}
else if (dataBuffer instanceof DefaultDataBuffer) {
return DefaultPayload.create(((DefaultDataBuffer) dataBuffer).getNativeBuffer());
}
else {
return DefaultPayload.create(dataBuffer.asByteBuffer());
}
}
}

View File

@ -0,0 +1,58 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import io.rsocket.RSocket;
import reactor.core.publisher.Mono;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver;
import org.springframework.util.Assert;
/**
* Resolves arguments of type {@link RSocket} that can be used for making
* requests to the remote peer.
*
* @author Rossen Stoyanchev
* @since 5.2
*/
public class SendingRSocketMethodArgumentResolver implements HandlerMethodArgumentResolver {
/**
* Message header name that is expected to have the {@link RSocket} to
* initiate new interactions to the remote peer with.
*/
public static final String SENDING_RSOCKET_HEADER = "sendingRSocket";
@Override
public boolean supportsParameter(MethodParameter parameter) {
return RSocket.class.isAssignableFrom(parameter.getParameterType());
}
@Override
public Mono<Object> resolveArgument(MethodParameter parameter, Message<?> message) {
Object headerValue = message.getHeaders().get(SENDING_RSOCKET_HEADER);
Assert.notNull(headerValue, "Missing '" + SENDING_RSOCKET_HEADER + "'");
Assert.isInstanceOf(RSocket.class, headerValue, "Expected header value of type io.rsocket.RSocket");
return Mono.just(headerValue);
}
}

View File

@ -0,0 +1,9 @@
/**
* Support for the RSocket protocol.
*/
@NonNullApi
@NonNullFields
package org.springframework.messaging.rsocket;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@ -0,0 +1,78 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import io.rsocket.AbstractRSocket;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.plugins.RSocketInterceptor;
import io.rsocket.util.RSocketProxy;
import reactor.core.publisher.Mono;
/**
* Intercept received RSockets and count successfully completed requests seen
* on the server side. This is useful for verifying fire-and-forget
* interactions.
*
* @author Rossen Stoyanchev
*/
class FireAndForgetCountingInterceptor extends AbstractRSocket implements RSocketInterceptor {
private final List<CountingDecorator> rsockets = new CopyOnWriteArrayList<>();
public int getRSocketCount() {
return this.rsockets.size();
}
public int getFireAndForgetCount(int index) {
return this.rsockets.get(index).getFireAndForgetCount();
}
@Override
public RSocket apply(RSocket rsocket) {
CountingDecorator decorator = new CountingDecorator(rsocket);
this.rsockets.add(decorator);
return decorator;
}
private static class CountingDecorator extends RSocketProxy {
private final AtomicInteger fireAndForget = new AtomicInteger(0);
CountingDecorator(RSocket delegate) {
super(delegate);
}
public int getFireAndForgetCount() {
return this.fireAndForget.get();
}
@Override
public Mono<Void> fireAndForget(Payload payload) {
return super.fireAndForget(payload).doOnSuccess(aVoid -> this.fireAndForget.incrementAndGet());
}
}
}

View File

@ -0,0 +1,235 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.time.Duration;
import java.util.Collections;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.RSocketFactory;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.transport.netty.server.CloseableChannel;
import io.rsocket.transport.netty.server.TcpServerTransport;
import io.rsocket.util.DefaultPayload;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.ReplayProcessor;
import reactor.test.StepVerifier;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.messaging.ReactiveMessageChannel;
import org.springframework.messaging.ReactiveSubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.support.DefaultReactiveMessageChannel;
import org.springframework.stereotype.Controller;
import static org.junit.Assert.*;
/**
* Server-side handling of RSocket requests.
*
* @author Rossen Stoyanchev
*/
public class RSocketClientToServerIntegrationTests {
private static AnnotationConfigApplicationContext context;
private static CloseableChannel serverChannel;
private static FireAndForgetCountingInterceptor interceptor = new FireAndForgetCountingInterceptor();
private static RSocket clientRsocket;
@BeforeClass
@SuppressWarnings("ConstantConditions")
public static void setupOnce() {
context = new AnnotationConfigApplicationContext(ServerConfig.class);
MessagingAcceptor acceptor = new MessagingAcceptor(
context.getBean("rsocketChannel", ReactiveMessageChannel.class));
serverChannel = RSocketFactory.receive()
.addServerPlugin(interceptor)
.acceptor(acceptor)
.transport(TcpServerTransport.create("localhost", 7000))
.start()
.block();
clientRsocket = RSocketFactory.connect()
.dataMimeType("text/plain")
.transport(TcpClientTransport.create("localhost", 7000))
.start()
.block();
}
@AfterClass
public static void tearDownOnce() {
clientRsocket.dispose();
serverChannel.dispose();
}
@Test
public void fireAndForget() {
Flux.range(1, 3)
.concatMap(i -> clientRsocket.fireAndForget(payload("receive", "Hello " + i)))
.blockLast();
StepVerifier.create(context.getBean(ServerController.class).fireForgetPayloads)
.expectNext("Hello 1")
.expectNext("Hello 2")
.expectNext("Hello 3")
.thenCancel()
.verify(Duration.ofSeconds(5));
assertEquals(1, interceptor.getRSocketCount());
assertEquals("Fire and forget requests did not actually complete handling on the server side",
3, interceptor.getFireAndForgetCount(0));
}
@Test
public void echo() {
Flux<String> result = Flux.range(1, 3).concatMap(i ->
clientRsocket.requestResponse(payload("echo", "Hello " + i)).map(Payload::getDataUtf8));
StepVerifier.create(result)
.expectNext("Hello 1")
.expectNext("Hello 2")
.expectNext("Hello 3")
.verifyComplete();
}
@Test
public void echoAsync() {
Flux<String> result = Flux.range(1, 3).concatMap(i ->
clientRsocket.requestResponse(payload("echo-async", "Hello " + i)).map(Payload::getDataUtf8));
StepVerifier.create(result)
.expectNext("Hello 1 async")
.expectNext("Hello 2 async")
.expectNext("Hello 3 async")
.verifyComplete();
}
@Test
public void echoStream() {
Flux<String> result = clientRsocket.requestStream(payload("echo-stream", "Hello"))
.map(io.rsocket.Payload::getDataUtf8);
StepVerifier.create(result)
.expectNext("Hello 0")
.expectNextCount(5)
.expectNext("Hello 6")
.expectNext("Hello 7")
.thenCancel()
.verify();
}
@Test
public void echoChannel() {
Flux<Payload> payloads = Flux.concat(
Flux.just(payload("echo-channel", "Hello 1")),
Flux.range(2, 9).map(i -> DefaultPayload.create("Hello " + i)));
Flux<String> result = clientRsocket.requestChannel(payloads).map(Payload::getDataUtf8);
StepVerifier.create(result)
.expectNext("Hello 1 async")
.expectNextCount(7)
.expectNext("Hello 9 async")
.expectNext("Hello 10 async")
.verifyComplete();
}
private static Payload payload(String destination, String data) {
return DefaultPayload.create(data, destination);
}
@Controller
static class ServerController {
final ReplayProcessor<String> fireForgetPayloads = ReplayProcessor.create();
@MessageMapping("receive")
void receive(String payload) {
this.fireForgetPayloads.onNext(payload);
}
@MessageMapping("echo")
String echo(String payload) {
return payload;
}
@MessageMapping("echo-async")
Mono<String> echoAsync(String payload) {
return Mono.delay(Duration.ofMillis(10)).map(aLong -> payload + " async");
}
@MessageMapping("echo-stream")
Flux<String> echoStream(String payload) {
return Flux.interval(Duration.ofMillis(10)).map(aLong -> payload + " " + aLong);
}
@MessageMapping("echo-channel")
Flux<String> echoChannel(Flux<String> payloads) {
return payloads.delayElements(Duration.ofMillis(10)).map(payload -> payload + " async");
}
}
@Configuration
static class ServerConfig {
@Bean
public ServerController controller() {
return new ServerController();
}
@Bean
public ReactiveSubscribableChannel rsocketChannel() {
return new DefaultReactiveMessageChannel();
}
@Bean
public RSocketMessageHandler rsocketMessageHandler() {
RSocketMessageHandler handler = new RSocketMessageHandler(rsocketChannel());
handler.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes()));
handler.setEncoders(Collections.singletonList(CharSequenceEncoder.allMimeTypes()));
return handler;
}
}
}

View File

@ -0,0 +1,305 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.RSocketFactory;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.transport.netty.server.CloseableChannel;
import io.rsocket.transport.netty.server.TcpServerTransport;
import io.rsocket.util.DefaultPayload;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import reactor.core.publisher.ReplayProcessor;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.messaging.ReactiveMessageChannel;
import org.springframework.messaging.ReactiveSubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.support.DefaultReactiveMessageChannel;
import org.springframework.stereotype.Controller;
/**
* Client-side handling of requests initiated from the server side.
*
* @author Rossen Stoyanchev
*/
public class RSocketServerToClientIntegrationTests {
private static AnnotationConfigApplicationContext context;
private static CloseableChannel serverChannel;
private static MessagingAcceptor clientAcceptor;
@BeforeClass
@SuppressWarnings("ConstantConditions")
public static void setupOnce() {
context = new AnnotationConfigApplicationContext(ServerConfig.class);
clientAcceptor = new MessagingAcceptor(
context.getBean("clientChannel", ReactiveMessageChannel.class));
MessagingAcceptor serverAcceptor = new MessagingAcceptor(
context.getBean("serverChannel", ReactiveMessageChannel.class));
serverChannel = RSocketFactory.receive()
.acceptor(serverAcceptor)
.transport(TcpServerTransport.create("localhost", 7000))
.start()
.block();
}
@AfterClass
public static void tearDownOnce() {
serverChannel.dispose();
}
@Test
public void echo() {
connectAndVerify("connect.echo");
}
@Test
public void echoAsync() {
connectAndVerify("connect.echo-async");
}
@Test
public void echoStream() {
connectAndVerify("connect.echo-stream");
}
@Test
public void echoChannel() {
connectAndVerify("connect.echo-channel");
}
private static void connectAndVerify(String destination) {
ServerController serverController = context.getBean(ServerController.class);
serverController.reset();
RSocket rsocket = null;
try {
rsocket = RSocketFactory.connect()
.setupPayload(DefaultPayload.create("", destination))
.dataMimeType("text/plain")
.acceptor(clientAcceptor)
.transport(TcpClientTransport.create("localhost", 7000))
.start()
.block();
serverController.await(Duration.ofSeconds(5));
}
finally {
if (rsocket != null) {
rsocket.dispose();
}
}
}
@Controller
@SuppressWarnings({"unused", "NullableProblems"})
static class ServerController {
// Must be initialized by @Test method...
volatile MonoProcessor<Void> result;
@MessageMapping("connect.echo")
void echo(RSocket rsocket) {
runTest(() -> {
Flux<String> result = Flux.range(1, 3).concatMap(i ->
rsocket.requestResponse(payload("echo", "Hello " + i)).map(Payload::getDataUtf8));
StepVerifier.create(result)
.expectNext("Hello 1")
.expectNext("Hello 2")
.expectNext("Hello 3")
.verifyComplete();
});
}
@MessageMapping("connect.echo-async")
void echoAsync(RSocket rsocket) {
runTest(() -> {
Flux<String> result = Flux.range(1, 3).concatMap(i ->
rsocket.requestResponse(payload("echo-async", "Hello " + i)).map(Payload::getDataUtf8));
StepVerifier.create(result)
.expectNext("Hello 1 async")
.expectNext("Hello 2 async")
.expectNext("Hello 3 async")
.verifyComplete();
});
}
@MessageMapping("connect.echo-stream")
void echoStream(RSocket rsocket) {
runTest(() -> {
Flux<String> result = rsocket.requestStream(payload("echo-stream", "Hello"))
.map(io.rsocket.Payload::getDataUtf8);
StepVerifier.create(result)
.expectNext("Hello 0")
.expectNextCount(5)
.expectNext("Hello 6")
.expectNext("Hello 7")
.thenCancel()
.verify();
});
}
@MessageMapping("connect.echo-channel")
void echoChannel(RSocket rsocket) {
runTest(() -> {
Flux<Payload> payloads = Flux.concat(
Flux.just(payload("echo-channel", "Hello 1")),
Flux.range(2, 9).map(i -> DefaultPayload.create("Hello " + i)));
Flux<String> result = rsocket.requestChannel(payloads).map(Payload::getDataUtf8);
StepVerifier.create(result)
.expectNext("Hello 1 async")
.expectNextCount(7)
.expectNext("Hello 9 async")
.expectNext("Hello 10 async")
.verifyComplete();
});
}
private void runTest(Runnable testEcho) {
Mono.fromRunnable(testEcho)
.doOnError(ex -> result.onError(ex))
.doOnSuccess(o -> result.onComplete())
.subscribeOn(Schedulers.elastic())
.subscribe();
}
private static Payload payload(String destination, String data) {
return DefaultPayload.create(data, destination);
}
public void reset() {
this.result = MonoProcessor.create();
}
public void await(Duration duration) {
this.result.block(duration);
}
}
private static class ClientController {
final ReplayProcessor<String> fireForgetPayloads = ReplayProcessor.create();
@MessageMapping("receive")
void receive(String payload) {
this.fireForgetPayloads.onNext(payload);
}
@MessageMapping("echo")
String echo(String payload) {
return payload;
}
@MessageMapping("echo-async")
Mono<String> echoAsync(String payload) {
return Mono.delay(Duration.ofMillis(10)).map(aLong -> payload + " async");
}
@MessageMapping("echo-stream")
Flux<String> echoStream(String payload) {
return Flux.interval(Duration.ofMillis(10)).map(aLong -> payload + " " + aLong);
}
@MessageMapping("echo-channel")
Flux<String> echoChannel(Flux<String> payloads) {
return payloads.delayElements(Duration.ofMillis(10)).map(payload -> payload + " async");
}
}
@Configuration
static class ServerConfig {
@Bean
public ClientController clientController() {
return new ClientController();
}
@Bean
public ServerController serverController() {
return new ServerController();
}
@Bean
public ReactiveSubscribableChannel clientChannel() {
return new DefaultReactiveMessageChannel();
}
@Bean
public ReactiveSubscribableChannel serverChannel() {
return new DefaultReactiveMessageChannel();
}
@Bean
public RSocketMessageHandler clientMessageHandler() {
List<Object> handlers = Collections.singletonList(clientController());
RSocketMessageHandler handler = new RSocketMessageHandler(clientChannel(), handlers);
addDefaultCodecs(handler);
return handler;
}
@Bean
public RSocketMessageHandler serverMessageHandler() {
RSocketMessageHandler handler = new RSocketMessageHandler(serverChannel());
addDefaultCodecs(handler);
return handler;
}
private void addDefaultCodecs(RSocketMessageHandler handler) {
handler.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes()));
handler.setEncoders(Collections.singletonList(CharSequenceEncoder.allMimeTypes()));
}
}
}