MessagingAcceptor/RSocket refinements + upgrade to 0.11.17

See gh-21987
This commit is contained in:
Rossen Stoyanchev 2019-02-25 12:56:32 -05:00
parent 8bdd709683
commit d6f4ec8c33
10 changed files with 108 additions and 119 deletions

View File

@ -7,7 +7,7 @@ dependencyManagement {
}
}
def rsocketVersion = "0.11.15"
def rsocketVersion = "0.11.17"
dependencies {
compile(project(":spring-beans"))

View File

@ -149,9 +149,14 @@ final class DefaultRSocketRequester implements RSocketRequester {
.concatMap(value -> encodeValue(value, dataType, encoder))
.switchOnFirst((signal, inner) -> {
DataBuffer data = signal.get();
return data != null ?
Flux.concat(Mono.just(firstPayload(data)), inner.skip(1).map(PayloadUtils::asPayload)) :
inner.map(PayloadUtils::asPayload);
if (data != null) {
return Flux.concat(
Mono.just(firstPayload(data)),
inner.skip(1).map(PayloadUtils::createPayload));
}
else {
return inner.map(PayloadUtils::createPayload);
}
})
.switchIfEmpty(emptyPayload());
return new DefaultResponseSpec(payloadFlux);
@ -167,7 +172,7 @@ final class DefaultRSocketRequester implements RSocketRequester {
}
private Payload firstPayload(DataBuffer data) {
return PayloadUtils.asPayload(getMetadata(), data);
return PayloadUtils.createPayload(getMetadata(), data);
}
private Mono<Payload> emptyPayload() {
@ -239,7 +244,7 @@ final class DefaultRSocketRequester implements RSocketRequester {
Decoder<?> decoder = strategies.decoder(elementType, dataMimeType);
return (Mono<T>) decoder.decodeToMono(
payloadMono.map(this::asDataBuffer), elementType, dataMimeType, EMPTY_HINTS);
payloadMono.map(this::wrapPayloadData), elementType, dataMimeType, EMPTY_HINTS);
}
@SuppressWarnings("unchecked")
@ -255,12 +260,12 @@ final class DefaultRSocketRequester implements RSocketRequester {
Decoder<?> decoder = strategies.decoder(elementType, dataMimeType);
return payloadFlux.map(this::asDataBuffer).concatMap(dataBuffer ->
return payloadFlux.map(this::wrapPayloadData).concatMap(dataBuffer ->
(Mono<T>) decoder.decodeToMono(Mono.just(dataBuffer), elementType, dataMimeType, EMPTY_HINTS));
}
private DataBuffer asDataBuffer(Payload payload) {
return PayloadUtils.asDataBuffer(payload, strategies.dataBufferFactory());
private DataBuffer wrapPayloadData(Payload payload) {
return PayloadUtils.wrapPayloadData(payload, strategies.dataBufferFactory());
}
}

View File

@ -28,7 +28,6 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.ReactiveMessageChannel;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
/**
* RSocket acceptor for
@ -79,10 +78,9 @@ public final class MessagingAcceptor implements SocketAcceptor, Function<RSocket
/**
* Configure the default content type for data payloads. For server
* acceptors this is available from the {@link ConnectionSetupPayload} but
* for client acceptors it's not and must be provided here.
* <p>By default this is not set.
* Configure the default content type to use for data payloads.
* <p>By default this is not set. However a server acceptor will use the
* content type from the {@link ConnectionSetupPayload}.
* @param defaultDataMimeType the MimeType to use
*/
public void setDefaultDataMimeType(@Nullable MimeType defaultDataMimeType) {
@ -92,21 +90,18 @@ public final class MessagingAcceptor implements SocketAcceptor, Function<RSocket
@Override
public Mono<RSocket> accept(ConnectionSetupPayload setupPayload, RSocket sendingRSocket) {
MimeType mimeType = setupPayload.dataMimeType() != null ?
MimeTypeUtils.parseMimeType(setupPayload.dataMimeType()) : this.defaultDataMimeType;
MessagingRSocket rsocket = createRSocket(sendingRSocket, mimeType);
return rsocket.afterConnectionEstablished(setupPayload).then(Mono.just(rsocket));
MessagingRSocket rsocket = createRSocket(sendingRSocket);
rsocket.handleConnectionSetupPayload(setupPayload).subscribe();
return Mono.just(rsocket);
}
@Override
public RSocket apply(RSocket sendingRSocket) {
return createRSocket(sendingRSocket, this.defaultDataMimeType);
return createRSocket(sendingRSocket);
}
private MessagingRSocket createRSocket(RSocket sendingRSocket, @Nullable MimeType dataMimeType) {
return new MessagingRSocket(this.messageChannel, sendingRSocket, dataMimeType, this.rsocketStrategies);
private MessagingRSocket createRSocket(RSocket rsocket) {
return new MessagingRSocket(this.messageChannel, rsocket, this.defaultDataMimeType, this.rsocketStrategies);
}
}

View File

@ -17,6 +17,7 @@ package org.springframework.messaging.rsocket;
import java.util.function.Function;
import io.rsocket.AbstractRSocket;
import io.rsocket.ConnectionSetupPayload;
import io.rsocket.Payload;
import io.rsocket.RSocket;
@ -40,6 +41,8 @@ import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.StringUtils;
/**
* Package private implementation of {@link RSocket} that is is hooked into an
@ -49,90 +52,96 @@ import org.springframework.util.MimeType;
* @author Rossen Stoyanchev
* @since 5.2
*/
class MessagingRSocket implements RSocket {
class MessagingRSocket extends AbstractRSocket {
private final ReactiveMessageChannel messageChannel;
private final RSocketRequester requester;
@Nullable
private final MimeType dataMimeType;
private MimeType dataMimeType;
private final RSocketStrategies strategies;
MessagingRSocket(ReactiveMessageChannel messageChannel,
RSocket sendingRSocket, @Nullable MimeType dataMimeType, RSocketStrategies strategies) {
RSocket sendingRSocket, @Nullable MimeType defaultDataMimeType, RSocketStrategies strategies) {
Assert.notNull(messageChannel, "'messageChannel' is required");
Assert.notNull(sendingRSocket, "'sendingRSocket' is required");
this.messageChannel = messageChannel;
this.requester = RSocketRequester.create(sendingRSocket, dataMimeType, strategies);
this.dataMimeType = dataMimeType;
this.requester = RSocketRequester.create(sendingRSocket, defaultDataMimeType, strategies);
this.dataMimeType = defaultDataMimeType;
this.strategies = strategies;
}
public Mono<Void> afterConnectionEstablished(ConnectionSetupPayload payload) {
return execute(payload).flatMap(flux -> flux.take(0).then());
public Mono<Void> handleConnectionSetupPayload(ConnectionSetupPayload payload) {
if (StringUtils.hasText(payload.dataMimeType())) {
this.dataMimeType = MimeTypeUtils.parseMimeType(payload.dataMimeType());
}
return handle(payload);
}
@Override
public Mono<Void> fireAndForget(Payload payload) {
return execute(payload).flatMap(flux -> flux.take(0).then());
return handle(payload);
}
@Override
public Mono<Payload> requestResponse(Payload payload) {
return execute(payload).flatMap(Flux::next);
return handleAndReply(payload, Flux.just(payload)).next();
}
@Override
public Flux<Payload> requestStream(Payload payload) {
return execute(payload).flatMapMany(Function.identity());
return handleAndReply(payload, Flux.just(payload));
}
@Override
public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
return Flux.from(payloads)
.switchOnFirst((signal, inner) -> {
Payload first = signal.get();
return first != null ? execute(first, inner).flatMapMany(Function.identity()) : inner;
.switchOnFirst((signal, innerFlux) -> {
Payload firstPayload = signal.get();
return firstPayload == null ? innerFlux : handleAndReply(firstPayload, innerFlux);
});
}
@Override
public Mono<Void> metadataPush(Payload payload) {
return null;
// This won't be very useful until createHeaders starting doing something more with metadata..
return handle(payload);
}
private Mono<Flux<Payload>> execute(Payload payload) {
return execute(payload, Flux.just(payload));
}
private Mono<Flux<Payload>> execute(Payload firstPayload, Flux<Payload> payloads) {
private Mono<Void> handle(Payload payload) {
// TODO:
// Since we do retain(), we need to ensure buffers are released if not consumed,
// e.g. error before Flux subscribed to, no handler found, @MessageMapping ignores payload, etc.
Flux<DataBuffer> payloadDataBuffers = payloads
.map(payload -> PayloadUtils.asDataBuffer(payload, this.strategies.dataBufferFactory()))
.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
MonoProcessor<Flux<Payload>> replyMono = MonoProcessor.create();
MessageHeaders headers = createHeaders(firstPayload, replyMono);
Message<?> message = MessageBuilder.createMessage(payloadDataBuffers, headers);
Message<?> message = MessageBuilder.createMessage(
Mono.fromCallable(() -> wrapPayloadData(payload)),
createHeaders(payload, null));
return this.messageChannel.send(message).flatMap(result -> result ?
replyMono.isTerminated() ? replyMono : Mono.empty() :
Mono.error(new MessageDeliveryException("RSocket interaction not handled")));
Mono.empty() : Mono.error(new MessageDeliveryException("RSocket request not handled")));
}
private MessageHeaders createHeaders(Payload payload, MonoProcessor<?> replyMono) {
private Flux<Payload> handleAndReply(Payload firstPayload, Flux<Payload> payloads) {
MonoProcessor<Flux<Payload>> replyMono = MonoProcessor.create();
Message<?> message = MessageBuilder.createMessage(
payloads.map(this::wrapPayloadData).doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release),
createHeaders(firstPayload, replyMono));
return this.messageChannel.send(message).flatMapMany(result ->
result && replyMono.isTerminated() ? replyMono.flatMapMany(Function.identity()) :
Mono.error(new MessageDeliveryException("RSocket request not handled")));
}
private MessageHeaders createHeaders(Payload payload, @Nullable MonoProcessor<?> replyMono) {
// TODO:
// For now treat the metadata as a simple string with routing information.
// We'll have to get more sophisticated once the routing extension is completed.
// https://github.com/rsocket/rsocket-java/issues/568
@ -147,7 +156,10 @@ class MessagingRSocket implements RSocket {
}
headers.setHeader(RSocketRequesterMethodArgumentResolver.RSOCKET_REQUESTER_HEADER, this.requester);
headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono);
if (replyMono != null) {
headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono);
}
DataBufferFactory bufferFactory = this.strategies.dataBufferFactory();
headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, bufferFactory);
@ -155,13 +167,8 @@ class MessagingRSocket implements RSocket {
return headers.getMessageHeaders();
}
@Override
public Mono<Void> onClose() {
return null;
}
@Override
public void dispose() {
private DataBuffer wrapPayloadData(Payload payload) {
return PayloadUtils.wrapPayloadData(payload, this.strategies.dataBufferFactory());
}
}

View File

@ -15,9 +15,6 @@
*/
package org.springframework.messaging.rsocket;
import java.nio.ByteBuffer;
import io.netty.buffer.ByteBuf;
import io.rsocket.Payload;
import io.rsocket.util.ByteBufPayload;
import io.rsocket.util.DefaultPayload;
@ -44,7 +41,7 @@ abstract class PayloadUtils {
* @param bufferFactory the BufferFactory to use to wrap
* @return the DataBuffer wrapper
*/
public static DataBuffer asDataBuffer(Payload payload, DataBufferFactory bufferFactory) {
public static DataBuffer wrapPayloadData(Payload payload, DataBufferFactory bufferFactory) {
if (bufferFactory instanceof NettyDataBufferFactory) {
return ((NettyDataBufferFactory) bufferFactory).wrap(payload.retain().sliceData());
}
@ -59,12 +56,16 @@ abstract class PayloadUtils {
* @param data the data part for the payload
* @return the created Payload
*/
public static Payload asPayload(DataBuffer metadata, DataBuffer data) {
public static Payload createPayload(DataBuffer metadata, DataBuffer data) {
if (metadata instanceof NettyDataBuffer && data instanceof NettyDataBuffer) {
return ByteBufPayload.create(getByteBuf(data), getByteBuf(metadata));
return ByteBufPayload.create(
((NettyDataBuffer) data).getNativeBuffer(),
((NettyDataBuffer) metadata).getNativeBuffer());
}
else if (metadata instanceof DefaultDataBuffer && data instanceof DefaultDataBuffer) {
return DefaultPayload.create(getByteBuffer(data), getByteBuffer(metadata));
return DefaultPayload.create(
((DefaultDataBuffer) data).getNativeBuffer(),
((DefaultDataBuffer) metadata).getNativeBuffer());
}
else {
return DefaultPayload.create(data.asByteBuffer(), metadata.asByteBuffer());
@ -76,24 +77,16 @@ abstract class PayloadUtils {
* @param data the data part for the payload
* @return the created Payload
*/
public static Payload asPayload(DataBuffer data) {
public static Payload createPayload(DataBuffer data) {
if (data instanceof NettyDataBuffer) {
return ByteBufPayload.create(getByteBuf(data));
return ByteBufPayload.create(((NettyDataBuffer) data).getNativeBuffer());
}
else if (data instanceof DefaultDataBuffer) {
return DefaultPayload.create(getByteBuffer(data));
return DefaultPayload.create(((DefaultDataBuffer) data).getNativeBuffer());
}
else {
return DefaultPayload.create(data.asByteBuffer());
}
}
private static ByteBuf getByteBuf(DataBuffer dataBuffer) {
return ((NettyDataBuffer) dataBuffer).getNativeBuffer();
}
private static
ByteBuffer getByteBuffer(DataBuffer dataBuffer) {
return ((DefaultDataBuffer) dataBuffer).getNativeBuffer();
}
}

View File

@ -21,13 +21,10 @@ import java.util.List;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.ReactiveSubscribableChannel;
import org.springframework.messaging.handler.annotation.support.reactive.MessageMappingMessageHandler;
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* RSocket-specific extension of {@link MessageMappingMessageHandler}.
@ -124,14 +121,4 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler {
return handlers;
}
@Override
protected void handleNoMatch(@Nullable String destination, Message<?> message) {
// Ignore empty destination, probably the ConnectionSetupPayload
if (!StringUtils.isEmpty(destination)) {
super.handleNoMatch(destination, message);
throw new MessageDeliveryException("No handler for '" + destination + "'");
}
}
}

View File

@ -63,7 +63,7 @@ public class RSocketPayloadReturnValueHandler extends AbstractEncoderMethodRetur
Assert.isInstanceOf(MonoProcessor.class, headerValue, "Expected MonoProcessor");
MonoProcessor<Flux<Payload>> monoProcessor = (MonoProcessor<Flux<Payload>>) headerValue;
monoProcessor.onNext(encodedContent.map(PayloadUtils::asPayload));
monoProcessor.onNext(encodedContent.map(PayloadUtils::createPayload));
monoProcessor.onComplete();
return Mono.empty();

View File

@ -199,7 +199,7 @@ public class DefaultRSocketRequesterTests {
}
private Payload toPayload(String value) {
return PayloadUtils.asPayload(bufferFactory.wrap(value.getBytes(StandardCharsets.UTF_8)));
return PayloadUtils.createPayload(bufferFactory.wrap(value.getBytes(StandardCharsets.UTF_8)));
}

View File

@ -35,6 +35,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.ReactiveMessageChannel;
import org.springframework.messaging.ReactiveSubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageMapping;
@ -169,6 +170,12 @@ public class RSocketClientToServerIntegrationTests {
.verifyComplete();
}
@Test
public void noMatchingRoute() {
Mono<String> result = requester.route("invalid").data("anything").retrieveMono(String.class);
StepVerifier.create(result).verifyErrorMessage("RSocket request not handled");
}
@Controller
static class ServerController {

View File

@ -20,7 +20,6 @@ import java.util.Collections;
import java.util.List;
import io.rsocket.Closeable;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.RSocketFactory;
import io.rsocket.transport.netty.client.TcpClientTransport;
@ -140,13 +139,22 @@ public class RSocketServerToClientIntegrationTests {
volatile MonoProcessor<Void> result;
public void reset() {
this.result = MonoProcessor.create();
}
public void await(Duration duration) {
this.result.block(duration);
}
@MessageMapping("connect.echo")
void echo(RSocketRequester requester) {
runTest(() -> {
Flux<String> result = Flux.range(1, 3).concatMap(i ->
Flux<String> flux = Flux.range(1, 3).concatMap(i ->
requester.route("echo").data("Hello " + i).retrieveMono(String.class));
StepVerifier.create(result)
StepVerifier.create(flux)
.expectNext("Hello 1")
.expectNext("Hello 2")
.expectNext("Hello 3")
@ -157,10 +165,10 @@ public class RSocketServerToClientIntegrationTests {
@MessageMapping("connect.echo-async")
void echoAsync(RSocketRequester requester) {
runTest(() -> {
Flux<String> result = Flux.range(1, 3).concatMap(i ->
Flux<String> flux = Flux.range(1, 3).concatMap(i ->
requester.route("echo-async").data("Hello " + i).retrieveMono(String.class));
StepVerifier.create(result)
StepVerifier.create(flux)
.expectNext("Hello 1 async")
.expectNext("Hello 2 async")
.expectNext("Hello 3 async")
@ -171,9 +179,9 @@ public class RSocketServerToClientIntegrationTests {
@MessageMapping("connect.echo-stream")
void echoStream(RSocketRequester requester) {
runTest(() -> {
Flux<String> result = requester.route("echo-stream").data("Hello").retrieveFlux(String.class);
Flux<String> flux = requester.route("echo-stream").data("Hello").retrieveFlux(String.class);
StepVerifier.create(result)
StepVerifier.create(flux)
.expectNext("Hello 0")
.expectNextCount(5)
.expectNext("Hello 6")
@ -186,11 +194,11 @@ public class RSocketServerToClientIntegrationTests {
@MessageMapping("connect.echo-channel")
void echoChannel(RSocketRequester requester) {
runTest(() -> {
Flux<String> result = requester.route("echo-channel")
Flux<String> flux = requester.route("echo-channel")
.data(Flux.range(1, 10).map(i -> "Hello " + i), String.class)
.retrieveFlux(String.class);
StepVerifier.create(result)
StepVerifier.create(flux)
.expectNext("Hello 1 async")
.expectNextCount(7)
.expectNext("Hello 9 async")
@ -207,19 +215,6 @@ public class RSocketServerToClientIntegrationTests {
.subscribeOn(Schedulers.elastic())
.subscribe();
}
private static Payload payload(String destination, String data) {
return DefaultPayload.create(data, destination);
}
public void reset() {
this.result = MonoProcessor.create();
}
public void await(Duration duration) {
this.result.block(duration);
}
}