RSocketRequester, RSocketStrategies, PayloadUtils

See gh-21987
This commit is contained in:
Rossen Stoyanchev 2019-02-18 16:41:29 -05:00
parent 4e78b5df2f
commit 8bdd709683
14 changed files with 1263 additions and 115 deletions

View File

@ -31,8 +31,8 @@ import org.springframework.messaging.Message;
*/
public interface HandlerMethodReturnValueHandler {
/** Header containing a DataBufferFactory to use. */
public static final String DATA_BUFFER_FACTORY_HEADER = "dataBufferFactoryHeader";
/** Header containing a DataBufferFactory for use in return value handling. */
String DATA_BUFFER_FACTORY_HEADER = "dataBufferFactory";
/**

View File

@ -0,0 +1,267 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
/**
* Default, package-private {@link RSocketRequester} implementation.
*
* @author Rossen Stoyanchev
* @since 5.2
*/
final class DefaultRSocketRequester implements RSocketRequester {
private static final Map<String, Object> EMPTY_HINTS = Collections.emptyMap();
private final RSocket rsocket;
@Nullable
private final MimeType dataMimeType;
private final RSocketStrategies strategies;
private DataBuffer emptyDataBuffer;
DefaultRSocketRequester(RSocket rsocket, @Nullable MimeType dataMimeType, RSocketStrategies strategies) {
Assert.notNull(rsocket, "RSocket is required");
Assert.notNull(strategies, "RSocketStrategies is required");
this.rsocket = rsocket;
this.dataMimeType = dataMimeType;
this.strategies = strategies;
this.emptyDataBuffer = this.strategies.dataBufferFactory().wrap(new byte[0]);
}
@Override
public RSocket rsocket() {
return this.rsocket;
}
@Override
public RequestSpec route(String route) {
return new DefaultRequestSpec(route);
}
private static boolean isVoid(ResolvableType elementType) {
return Void.class.equals(elementType.resolve()) || void.class.equals(elementType.resolve());
}
private class DefaultRequestSpec implements RequestSpec {
private final String route;
DefaultRequestSpec(String route) {
this.route = route;
}
@Override
public ResponseSpec data(Object data) {
Assert.notNull(data, "'data' must not be null");
return toResponseSpec(data, ResolvableType.NONE);
}
@Override
public <T, P extends Publisher<T>> ResponseSpec data(P publisher, Class<T> dataType) {
Assert.notNull(publisher, "'publisher' must not be null");
Assert.notNull(dataType, "'dataType' must not be null");
return toResponseSpec(publisher, ResolvableType.forClass(dataType));
}
@Override
public <T, P extends Publisher<T>> ResponseSpec data(P publisher, ParameterizedTypeReference<T> dataTypeRef) {
Assert.notNull(publisher, "'publisher' must not be null");
Assert.notNull(dataTypeRef, "'dataTypeRef' must not be null");
return toResponseSpec(publisher, ResolvableType.forType(dataTypeRef));
}
private ResponseSpec toResponseSpec(Object input, ResolvableType dataType) {
ReactiveAdapter adapter = strategies.reactiveAdapterRegistry().getAdapter(input.getClass());
Publisher<?> publisher;
if (input instanceof Publisher) {
publisher = (Publisher<?>) input;
}
else if (adapter != null) {
publisher = adapter.toPublisher(input);
}
else {
Mono<Payload> payloadMono = encodeValue(input, ResolvableType.forInstance(input), null)
.map(this::firstPayload)
.switchIfEmpty(emptyPayload());
return new DefaultResponseSpec(payloadMono);
}
if (isVoid(dataType) || (adapter != null && adapter.isNoValue())) {
Mono<Payload> payloadMono = Mono.when(publisher).then(emptyPayload());
return new DefaultResponseSpec(payloadMono);
}
Encoder<?> encoder = dataType != ResolvableType.NONE && !Object.class.equals(dataType.resolve()) ?
strategies.encoder(dataType, dataMimeType) : null;
if (adapter != null && !adapter.isMultiValue()) {
Mono<Payload> payloadMono = Mono.from(publisher)
.flatMap(value -> encodeValue(value, dataType, encoder))
.map(this::firstPayload)
.switchIfEmpty(emptyPayload());
return new DefaultResponseSpec(payloadMono);
}
Flux<Payload> payloadFlux = Flux.from(publisher)
.concatMap(value -> encodeValue(value, dataType, encoder))
.switchOnFirst((signal, inner) -> {
DataBuffer data = signal.get();
return data != null ?
Flux.concat(Mono.just(firstPayload(data)), inner.skip(1).map(PayloadUtils::asPayload)) :
inner.map(PayloadUtils::asPayload);
})
.switchIfEmpty(emptyPayload());
return new DefaultResponseSpec(payloadFlux);
}
@SuppressWarnings("unchecked")
private <T> Mono<DataBuffer> encodeValue(T value, ResolvableType valueType, @Nullable Encoder<?> encoder) {
if (encoder == null) {
encoder = strategies.encoder(ResolvableType.forInstance(value), dataMimeType);
}
return DataBufferUtils.join(((Encoder<T>) encoder).encode(
Mono.just(value), strategies.dataBufferFactory(), valueType, dataMimeType, EMPTY_HINTS));
}
private Payload firstPayload(DataBuffer data) {
return PayloadUtils.asPayload(getMetadata(), data);
}
private Mono<Payload> emptyPayload() {
return Mono.fromCallable(() -> firstPayload(emptyDataBuffer));
}
private DataBuffer getMetadata() {
return strategies.dataBufferFactory().wrap(this.route.getBytes(StandardCharsets.UTF_8));
}
}
private class DefaultResponseSpec implements ResponseSpec {
@Nullable
private final Mono<Payload> payloadMono;
@Nullable
private final Flux<Payload> payloadFlux;
DefaultResponseSpec(Mono<Payload> payloadMono) {
this.payloadMono = payloadMono;
this.payloadFlux = null;
}
DefaultResponseSpec(Flux<Payload> payloadFlux) {
this.payloadMono = null;
this.payloadFlux = payloadFlux;
}
@Override
public Mono<Void> send() {
Assert.notNull(this.payloadMono, "No RSocket interaction model for one-way send with Flux.");
return this.payloadMono.flatMap(rsocket::fireAndForget);
}
@Override
public <T> Mono<T> retrieveMono(Class<T> dataType) {
return retrieveMono(ResolvableType.forClass(dataType));
}
@Override
public <T> Mono<T> retrieveMono(ParameterizedTypeReference<T> dataTypeRef) {
return retrieveMono(ResolvableType.forType(dataTypeRef));
}
@Override
public <T> Flux<T> retrieveFlux(Class<T> dataType) {
return retrieveFlux(ResolvableType.forClass(dataType));
}
@Override
public <T> Flux<T> retrieveFlux(ParameterizedTypeReference<T> dataTypeRef) {
return retrieveFlux(ResolvableType.forType(dataTypeRef));
}
@SuppressWarnings("unchecked")
private <T> Mono<T> retrieveMono(ResolvableType elementType) {
Assert.notNull(this.payloadMono,
"No RSocket interaction model for Flux request to Mono response.");
Mono<Payload> payloadMono = this.payloadMono.flatMap(rsocket::requestResponse);
if (isVoid(elementType)) {
return (Mono<T>) payloadMono.then();
}
Decoder<?> decoder = strategies.decoder(elementType, dataMimeType);
return (Mono<T>) decoder.decodeToMono(
payloadMono.map(this::asDataBuffer), elementType, dataMimeType, EMPTY_HINTS);
}
@SuppressWarnings("unchecked")
private <T> Flux<T> retrieveFlux(ResolvableType elementType) {
Flux<Payload> payloadFlux = this.payloadMono != null ?
this.payloadMono.flatMapMany(rsocket::requestStream) :
rsocket.requestChannel(this.payloadFlux);
if (isVoid(elementType)) {
return payloadFlux.thenMany(Flux.empty());
}
Decoder<?> decoder = strategies.decoder(elementType, dataMimeType);
return payloadFlux.map(this::asDataBuffer).concatMap(dataBuffer ->
(Mono<T>) decoder.decodeToMono(Mono.just(dataBuffer), elementType, dataMimeType, EMPTY_HINTS));
}
private DataBuffer asDataBuffer(Payload payload) {
return PayloadUtils.asDataBuffer(payload, strategies.dataBufferFactory());
}
}
}

View File

@ -0,0 +1,144 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;
import io.netty.buffer.PooledByteBufAllocator;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.lang.Nullable;
/**
* Default, package-private {@link RSocketStrategies} implementation.
*
* @author Rossen Stoyanchev
* @since 5.2
*/
final class DefaultRSocketStrategies implements RSocketStrategies {
private final List<Encoder<?>> encoders;
private final List<Decoder<?>> decoders;
private final ReactiveAdapterRegistry adapterRegistry;
private final DataBufferFactory bufferFactory;
private DefaultRSocketStrategies(
List<Encoder<?>> encoders, List<Decoder<?>> decoders,
ReactiveAdapterRegistry adapterRegistry, DataBufferFactory bufferFactory) {
this.encoders = Collections.unmodifiableList(encoders);
this.decoders = Collections.unmodifiableList(decoders);
this.adapterRegistry = adapterRegistry;
this.bufferFactory = bufferFactory;
}
@Override
public List<Encoder<?>> encoders() {
return this.encoders;
}
@Override
public List<Decoder<?>> decoders() {
return this.decoders;
}
@Override
public ReactiveAdapterRegistry reactiveAdapterRegistry() {
return this.adapterRegistry;
}
@Override
public DataBufferFactory dataBufferFactory() {
return this.bufferFactory;
}
/**
* Default RSocketStrategies.Builder implementation.
*/
static class DefaultRSocketStrategiesBuilder implements RSocketStrategies.Builder {
private final List<Encoder<?>> encoders = new ArrayList<>();
private final List<Decoder<?>> decoders = new ArrayList<>();
@Nullable
private ReactiveAdapterRegistry adapterRegistry;
@Nullable
private DataBufferFactory bufferFactory;
@Override
public Builder encoder(Encoder<?>... encoders) {
this.encoders.addAll(Arrays.asList(encoders));
return this;
}
@Override
public Builder decoder(Decoder<?>... decoder) {
this.decoders.addAll(Arrays.asList(decoder));
return this;
}
@Override
public Builder encoders(Consumer<List<Encoder<?>>> consumer) {
consumer.accept(this.encoders);
return this;
}
@Override
public Builder decoders(Consumer<List<Decoder<?>>> consumer) {
consumer.accept(this.decoders);
return this;
}
@Override
public Builder reactiveAdapterStrategy(ReactiveAdapterRegistry registry) {
this.adapterRegistry = registry;
return this;
}
@Override
public Builder dataBufferFactory(DataBufferFactory bufferFactory) {
this.bufferFactory = bufferFactory;
return this;
}
@Override
public RSocketStrategies build() {
return new DefaultRSocketStrategies(this.encoders, this.decoders,
this.adapterRegistry != null ?
this.adapterRegistry : ReactiveAdapterRegistry.getSharedInstance(),
this.bufferFactory != null ? this.bufferFactory :
new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT));
}
}
}

View File

@ -18,13 +18,11 @@ package org.springframework.messaging.rsocket;
import java.util.function.Function;
import java.util.function.Predicate;
import io.netty.buffer.PooledByteBufAllocator;
import io.rsocket.ConnectionSetupPayload;
import io.rsocket.RSocket;
import io.rsocket.SocketAcceptor;
import reactor.core.publisher.Mono;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.ReactiveMessageChannel;
@ -47,7 +45,7 @@ public final class MessagingAcceptor implements SocketAcceptor, Function<RSocket
private final ReactiveMessageChannel messageChannel;
private NettyDataBufferFactory bufferFactory = new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT);
private final RSocketStrategies rsocketStrategies;
@Nullable
private MimeType defaultDataMimeType;
@ -64,8 +62,19 @@ public final class MessagingAcceptor implements SocketAcceptor, Function<RSocket
* or with handler instances.
*/
public MessagingAcceptor(ReactiveMessageChannel messageChannel) {
this(messageChannel, RSocketStrategies.builder().build());
}
/**
* Variant of {@link #MessagingAcceptor(ReactiveMessageChannel)} with an
* {@link RSocketStrategies} for wrapping the sending {@link RSocket} as
* {@link RSocketRequester}.
*/
public MessagingAcceptor(ReactiveMessageChannel messageChannel, RSocketStrategies rsocketStrategies) {
Assert.notNull(messageChannel, "ReactiveMessageChannel is required");
Assert.notNull(rsocketStrategies, "RSocketStrategies is required");
this.messageChannel = messageChannel;
this.rsocketStrategies = rsocketStrategies;
}
@ -80,17 +89,6 @@ public final class MessagingAcceptor implements SocketAcceptor, Function<RSocket
this.defaultDataMimeType = defaultDataMimeType;
}
/**
* Configure the buffer factory to use.
* <p>By default this is initialized with the allocator instance
* {@link PooledByteBufAllocator#DEFAULT}.
* @param bufferFactory the bufferFactory to use
*/
public void setNettyDataBufferFactory(NettyDataBufferFactory bufferFactory) {
Assert.notNull(bufferFactory, "DataBufferFactory is required");
this.bufferFactory = bufferFactory;
}
@Override
public Mono<RSocket> accept(ConnectionSetupPayload setupPayload, RSocket sendingRSocket) {
@ -108,7 +106,7 @@ public final class MessagingAcceptor implements SocketAcceptor, Function<RSocket
}
private MessagingRSocket createRSocket(RSocket sendingRSocket, @Nullable MimeType dataMimeType) {
return new MessagingRSocket(this.messageChannel, this.bufferFactory, sendingRSocket, dataMimeType);
return new MessagingRSocket(this.messageChannel, sendingRSocket, dataMimeType, this.rsocketStrategies);
}
}

View File

@ -25,9 +25,9 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBuffer;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.core.io.buffer.PooledDataBuffer;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
@ -42,8 +42,9 @@ import org.springframework.util.Assert;
import org.springframework.util.MimeType;
/**
* Package private implementation of {@link RSocket} used from
* {@link MessagingAcceptor}.
* Package private implementation of {@link RSocket} that is is hooked into an
* RSocket client or server via {@link MessagingAcceptor} to accept and handle
* requests.
*
* @author Rossen Stoyanchev
* @since 5.2
@ -52,24 +53,23 @@ class MessagingRSocket implements RSocket {
private final ReactiveMessageChannel messageChannel;
private final NettyDataBufferFactory bufferFactory;
private final RSocket sendingRSocket;
private final RSocketRequester requester;
@Nullable
private final MimeType dataMimeType;
private final RSocketStrategies strategies;
MessagingRSocket(ReactiveMessageChannel messageChannel, NettyDataBufferFactory bufferFactory,
RSocket sendingRSocket, @Nullable MimeType dataMimeType) {
MessagingRSocket(ReactiveMessageChannel messageChannel,
RSocket sendingRSocket, @Nullable MimeType dataMimeType, RSocketStrategies strategies) {
Assert.notNull(messageChannel, "'messageChannel' is required");
Assert.notNull(bufferFactory, "'bufferFactory' is required");
Assert.notNull(sendingRSocket, "'sendingRSocket' is required");
this.messageChannel = messageChannel;
this.bufferFactory = bufferFactory;
this.sendingRSocket = sendingRSocket;
this.requester = RSocketRequester.create(sendingRSocket, dataMimeType, strategies);
this.dataMimeType = dataMimeType;
this.strategies = strategies;
}
@ -117,8 +117,8 @@ class MessagingRSocket implements RSocket {
// Since we do retain(), we need to ensure buffers are released if not consumed,
// e.g. error before Flux subscribed to, no handler found, @MessageMapping ignores payload, etc.
Flux<NettyDataBuffer> payloadDataBuffers = payloads
.map(payload -> this.bufferFactory.wrap(payload.retain().sliceData()))
Flux<DataBuffer> payloadDataBuffers = payloads
.map(payload -> PayloadUtils.asDataBuffer(payload, this.strategies.dataBufferFactory()))
.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
MonoProcessor<Flux<Payload>> replyMono = MonoProcessor.create();
@ -146,9 +146,11 @@ class MessagingRSocket implements RSocket {
headers.setContentType(this.dataMimeType);
}
headers.setHeader(SendingRSocketMethodArgumentResolver.SENDING_RSOCKET_HEADER, this.sendingRSocket);
headers.setHeader(RSocketRequesterMethodArgumentResolver.RSOCKET_REQUESTER_HEADER, this.requester);
headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono);
headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, this.bufferFactory);
DataBufferFactory bufferFactory = this.strategies.dataBufferFactory();
headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, bufferFactory);
return headers.getMessageHeaders();
}

View File

@ -0,0 +1,99 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.nio.ByteBuffer;
import io.netty.buffer.ByteBuf;
import io.rsocket.Payload;
import io.rsocket.util.ByteBufPayload;
import io.rsocket.util.DefaultPayload;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBuffer;
import org.springframework.core.io.buffer.NettyDataBuffer;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
/**
* Static utility methods to create {@link Payload} from {@link DataBuffer}s
* and vice versa.
*
* @author Rossen Stoyanchev
* @since 5.2
*/
abstract class PayloadUtils {
/**
* Return the Payload data wrapped as DataBuffer. If the bufferFactory is
* {@link NettyDataBufferFactory} the payload retained and sliced.
* @param payload the input payload
* @param bufferFactory the BufferFactory to use to wrap
* @return the DataBuffer wrapper
*/
public static DataBuffer asDataBuffer(Payload payload, DataBufferFactory bufferFactory) {
if (bufferFactory instanceof NettyDataBufferFactory) {
return ((NettyDataBufferFactory) bufferFactory).wrap(payload.retain().sliceData());
}
else {
return bufferFactory.wrap(payload.getData());
}
}
/**
* Create a Payload from the given metadata and data.
* @param metadata the metadata part for the payload
* @param data the data part for the payload
* @return the created Payload
*/
public static Payload asPayload(DataBuffer metadata, DataBuffer data) {
if (metadata instanceof NettyDataBuffer && data instanceof NettyDataBuffer) {
return ByteBufPayload.create(getByteBuf(data), getByteBuf(metadata));
}
else if (metadata instanceof DefaultDataBuffer && data instanceof DefaultDataBuffer) {
return DefaultPayload.create(getByteBuffer(data), getByteBuffer(metadata));
}
else {
return DefaultPayload.create(data.asByteBuffer(), metadata.asByteBuffer());
}
}
/**
* Create a Payload from the given data.
* @param data the data part for the payload
* @return the created Payload
*/
public static Payload asPayload(DataBuffer data) {
if (data instanceof NettyDataBuffer) {
return ByteBufPayload.create(getByteBuf(data));
}
else if (data instanceof DefaultDataBuffer) {
return DefaultPayload.create(getByteBuffer(data));
}
else {
return DefaultPayload.create(data.asByteBuffer());
}
}
private static ByteBuf getByteBuf(DataBuffer dataBuffer) {
return ((NettyDataBuffer) dataBuffer).getNativeBuffer();
}
private static
ByteBuffer getByteBuffer(DataBuffer dataBuffer) {
return ((DefaultDataBuffer) dataBuffer).getNativeBuffer();
}
}

View File

@ -18,6 +18,7 @@ package org.springframework.messaging.rsocket;
import java.util.ArrayList;
import java.util.List;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
@ -25,6 +26,7 @@ import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.ReactiveSubscribableChannel;
import org.springframework.messaging.handler.annotation.support.reactive.MessageMappingMessageHandler;
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
@ -41,6 +43,9 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler {
private final List<Encoder<?>> encoders = new ArrayList<>();
@Nullable
private RSocketStrategies rsocketStrategies;
public RSocketMessageHandler(ReactiveSubscribableChannel inboundChannel) {
super(inboundChannel);
@ -55,6 +60,7 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler {
}
/**
* Configure the encoders to use for encoding handler method return values.
*/
@ -69,10 +75,44 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler {
return this.encoders;
}
/**
* Provide configuration in the form of {@link RSocketStrategies}. This is
* an alternative to using {@link #setEncoders(List)},
* {@link #setDecoders(List)}, and others directly. It is convenient when
* you also need to configure an {@link RSocketRequester} in which case
* the strategies can be configured once and used in multiple places.
* @param rsocketStrategies the strategies to use
*/
public void setRSocketStrategies(RSocketStrategies rsocketStrategies) {
Assert.notNull(rsocketStrategies, "RSocketStrategies must not be null");
this.rsocketStrategies = rsocketStrategies;
setDecoders(rsocketStrategies.decoders());
setEncoders(rsocketStrategies.encoders());
setReactiveAdapterRegistry(rsocketStrategies.reactiveAdapterRegistry());
}
/**
* Return the {@code RSocketStrategies} instance provided via
* {@link #setRSocketStrategies rsocketStrategies}, or
* otherwise a new instance populated with the configured
* {@link #setEncoders(List) encoders}, {@link #setDecoders(List) decoders}
* and others.
*/
public RSocketStrategies getRSocketStrategies() {
if (this.rsocketStrategies != null) {
return this.rsocketStrategies;
}
return RSocketStrategies.builder()
.decoder(getDecoders().toArray(new Decoder<?>[0]))
.encoder(getEncoders().toArray(new Encoder<?>[0]))
.reactiveAdapterStrategy(getReactiveAdapterRegistry())
.build();
}
@Override
public void afterPropertiesSet() {
getArgumentResolverConfigurer().addCustomResolver(new SendingRSocketMethodArgumentResolver());
getArgumentResolverConfigurer().addCustomResolver(new RSocketRequesterMethodArgumentResolver());
super.afterPropertiesSet();
}

View File

@ -18,8 +18,6 @@ package org.springframework.messaging.rsocket;
import java.util.List;
import io.rsocket.Payload;
import io.rsocket.util.ByteBufPayload;
import io.rsocket.util.DefaultPayload;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
@ -28,8 +26,6 @@ import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DefaultDataBuffer;
import org.springframework.core.io.buffer.NettyDataBuffer;
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.invocation.reactive.AbstractEncoderMethodReturnValueHandler;
import org.springframework.util.Assert;
@ -67,22 +63,10 @@ public class RSocketPayloadReturnValueHandler extends AbstractEncoderMethodRetur
Assert.isInstanceOf(MonoProcessor.class, headerValue, "Expected MonoProcessor");
MonoProcessor<Flux<Payload>> monoProcessor = (MonoProcessor<Flux<Payload>>) headerValue;
monoProcessor.onNext(encodedContent.map(this::toPayload));
monoProcessor.onNext(encodedContent.map(PayloadUtils::asPayload));
monoProcessor.onComplete();
return Mono.empty();
}
private Payload toPayload(DataBuffer dataBuffer) {
if (dataBuffer instanceof NettyDataBuffer) {
return ByteBufPayload.create(((NettyDataBuffer) dataBuffer).getNativeBuffer());
}
else if (dataBuffer instanceof DefaultDataBuffer) {
return DefaultPayload.create(((DefaultDataBuffer) dataBuffer).getNativeBuffer());
}
else {
return DefaultPayload.create(dataBuffer.asByteBuffer());
}
}
}

View File

@ -0,0 +1,166 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import io.rsocket.RSocket;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.lang.Nullable;
import org.springframework.util.MimeType;
/**
* A thin wrapper around a sending {@link RSocket} with a fluent API accepting
* and returning higher level Objects for input and for output, along with
* methods specify routing and other metadata.
*
* @author Rossen Stoyanchev
* @since 5.2
*/
public interface RSocketRequester {
/**
* Return the underlying RSocket used to make requests.
*/
RSocket rsocket();
/**
* Create a new {@code RSocketRequester} from the given {@link RSocket} and
* strategies for encoding and decoding request and response payloads.
* @param rsocket the sending RSocket to use
* @param dataMimeType the MimeType for data (from the SETUP frame)
* @param strategies encoders, decoders, and others
* @return the created RSocketRequester wrapper
*/
static RSocketRequester create(RSocket rsocket, @Nullable MimeType dataMimeType, RSocketStrategies strategies) {
return new DefaultRSocketRequester(rsocket, dataMimeType, strategies);
}
// For now we treat metadata as a simple string that is the route.
// This will change after the resolution of:
// https://github.com/rsocket/rsocket-java/issues/568
/**
* Entry point to prepare a new request to the given route.
*
* <p>For requestChannel interactions, i.e. Flux-to-Flux the metadata is
* attached to the first request payload.
*
* @param route the routing destination
* @return a spec for further defining and executing the reuqest
*/
RequestSpec route(String route);
/**
* Contract to provide input data for an RSocket request.
*/
interface RequestSpec {
/**
* Provide request payload data. The given Object may be a synchronous
* value, or a {@link Publisher} of values, or another async type that's
* registered in the configured {@link ReactiveAdapterRegistry}.
* <p>For multivalued Publishers, prefer using
* {@link #data(Publisher, Class)} or
* {@link #data(Publisher, ParameterizedTypeReference)} since that makes
* it possible to find a compatible {@code Encoder} up front vs looking
* it up on every value.
* @param data the Object to use for payload data
* @return spec for declaring the expected response
*/
ResponseSpec data(Object data);
/**
* Provide a {@link Publisher} of value(s) for request payload data.
* <p>Publisher semantics determined through the configured
* {@link ReactiveAdapterRegistry} influence which of the 4 RSocket
* interactions to use. Publishers with unknown semantics are treated
* as multivalued. Consider registering a reactive type adapter, or
* passing {@code Mono.from(publisher)}.
* <p>If the publisher completes empty, possibly {@code Publisher<Void>},
* the request will have an empty data Payload.
* @param publisher source of payload data value(s)
* @param dataType the type of values to be published
* @param <T> the type of element values
* @param <P> the type of publisher
* @return spec for declaring the expected response
*/
<T, P extends Publisher<T>> ResponseSpec data(P publisher, Class<T> dataType);
/**
* Variant of {@link #data(Publisher, Class)} for when the dataType has
* to have a generic type. See {@link ParameterizedTypeReference}.
*/
<T, P extends Publisher<T>> ResponseSpec data(P publisher, ParameterizedTypeReference<T> dataTypeRef);
}
/**
* Contract to declare the expected RSocket response.
*/
interface ResponseSpec {
/**
* Perform {@link RSocket#fireAndForget fireAndForget}.
*/
Mono<Void> send();
/**
* Perform {@link RSocket#requestResponse requestResponse}. If the
* expected data type is {@code Void.class}, the returned {@code Mono}
* will complete after all data is consumed.
* <p><strong>Note:</strong> Use of this method will raise an error if
* the request payload is a multivalued {@link Publisher} as
* determined through the configured {@link ReactiveAdapterRegistry}.
* @param dataType the expected data type for the response
* @param <T> parameter for the expected data type
* @return the decoded response
*/
<T> Mono<T> retrieveMono(Class<T> dataType);
/**
* Variant of {@link #retrieveMono(Class)} for when the dataType has
* to have a generic type. See {@link ParameterizedTypeReference}.
*/
<T> Mono<T> retrieveMono(ParameterizedTypeReference<T> dataTypeRef);
/**
* Perform {@link RSocket#requestStream requestStream} or
* {@link RSocket#requestChannel requestChannel} depending on whether
* the request input consists of a single or multiple payloads.
* If the expected data type is {@code Void.class}, the returned
* {@code Flux} will complete after all data is consumed.
* @param dataType the expected type for values in the response
* @param <T> parameterize the expected type of values
* @return the decoded response
*/
<T> Flux<T> retrieveFlux(Class<T> dataType);
/**
* Variant of {@link #retrieveFlux(Class)} for when the dataType has
* to have a generic type. See {@link ParameterizedTypeReference}.
*/
<T> Flux<T> retrieveFlux(ParameterizedTypeReference<T> dataTypeRef);
}
}

View File

@ -31,28 +31,40 @@ import org.springframework.util.Assert;
* @author Rossen Stoyanchev
* @since 5.2
*/
public class SendingRSocketMethodArgumentResolver implements HandlerMethodArgumentResolver {
public class RSocketRequesterMethodArgumentResolver implements HandlerMethodArgumentResolver {
/**
* Message header name that is expected to have the {@link RSocket} to
* initiate new interactions to the remote peer with.
*/
public static final String SENDING_RSOCKET_HEADER = "sendingRSocket";
public static final String RSOCKET_REQUESTER_HEADER = "rsocketRequester";
@Override
public boolean supportsParameter(MethodParameter parameter) {
return RSocket.class.isAssignableFrom(parameter.getParameterType());
Class<?> type = parameter.getParameterType();
return RSocketRequester.class.equals(type) || RSocket.class.isAssignableFrom(type);
}
@Override
public Mono<Object> resolveArgument(MethodParameter parameter, Message<?> message) {
Object headerValue = message.getHeaders().get(SENDING_RSOCKET_HEADER);
Assert.notNull(headerValue, "Missing '" + SENDING_RSOCKET_HEADER + "'");
Assert.isInstanceOf(RSocket.class, headerValue, "Expected header value of type io.rsocket.RSocket");
Object headerValue = message.getHeaders().get(RSOCKET_REQUESTER_HEADER);
Assert.notNull(headerValue, "Missing '" + RSOCKET_REQUESTER_HEADER + "'");
Assert.isInstanceOf(RSocketRequester.class, headerValue, "Expected header value of type RSocketRequester");
return Mono.just(headerValue);
RSocketRequester requester = (RSocketRequester) headerValue;
Class<?> type = parameter.getParameterType();
if (RSocketRequester.class.equals(type)) {
return Mono.just(requester);
}
else if (RSocket.class.isAssignableFrom(type)) {
return Mono.just(requester.rsocket());
}
else {
return Mono.error(new IllegalArgumentException("Unexpected parameter type: " + parameter));
}
}
}

View File

@ -0,0 +1,160 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.util.List;
import java.util.function.Consumer;
import io.netty.buffer.PooledByteBufAllocator;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.lang.Nullable;
import org.springframework.util.MimeType;
/**
* Access to strategies for use by RSocket requester and responder components.
*
* @author Rossen Stoyanchev
* @since 5.2
*/
public interface RSocketStrategies {
/**
* Return the configured {@link Builder#encoder(Encoder[]) encoders}.
* @see #encoder(ResolvableType, MimeType)
*/
List<Encoder<?>> encoders();
/**
* Find a compatible Encoder for the given element type.
* @param elementType the element type to match
* @param mimeType the MimeType to match
* @param <T> for casting the Encoder to the expected element type
* @return the matching Encoder
* @throws IllegalArgumentException if no matching Encoder is found
*/
@SuppressWarnings("unchecked")
default <T> Encoder<T> encoder(ResolvableType elementType, @Nullable MimeType mimeType) {
for (Encoder<?> encoder : encoders()) {
if (encoder.canEncode(elementType, mimeType)) {
return (Encoder<T>) encoder;
}
}
throw new IllegalArgumentException("No encoder for " + elementType);
}
/**
* Return the configured {@link Builder#decoder(Decoder[]) decoders}.
* @see #decoder(ResolvableType, MimeType)
*/
List<Decoder<?>> decoders();
/**
* Find a compatible Decoder for the given element type.
* @param elementType the element type to match
* @param mimeType the MimeType to match
* @param <T> for casting the Decoder to the expected element type
* @return the matching Decoder
* @throws IllegalArgumentException if no matching Decoder is found
*/
@SuppressWarnings("unchecked")
default <T> Decoder<T> decoder(ResolvableType elementType, @Nullable MimeType mimeType) {
for (Decoder<?> decoder : decoders()) {
if (decoder.canDecode(elementType, mimeType)) {
return (Decoder<T>) decoder;
}
}
throw new IllegalArgumentException("No decoder for " + elementType);
}
/**
* Return the configured
* {@link Builder#reactiveAdapterStrategy(ReactiveAdapterRegistry) reactiveAdapterRegistry}.
*/
ReactiveAdapterRegistry reactiveAdapterRegistry();
/**
* Return the configured
* {@link Builder#dataBufferFactory(DataBufferFactory) dataBufferFactory}.
*/
DataBufferFactory dataBufferFactory();
/**
* Return a builder to build a new {@code RSocketStrategies} instance.
*/
static Builder builder() {
return new DefaultRSocketStrategies.DefaultRSocketStrategiesBuilder();
}
/**
* The builder options for creating {@code RSocketStrategies}.
*/
interface Builder {
/**
* Add encoders to use for serializing Objects.
* <p>By default this is empty.
*/
Builder encoder(Encoder<?>... encoder);
/**
* Add decoders for de-serializing Objects.
* <p>By default this is empty.
*/
Builder decoder(Decoder<?>... decoder);
/**
* Access and manipulate the list of configured {@link #encoder encoders}.
*/
Builder encoders(Consumer<List<Encoder<?>>> consumer);
/**
* Access and manipulate the list of configured {@link #encoder decoders}.
*/
Builder decoders(Consumer<List<Decoder<?>>> consumer);
/**
* Configure the registry for reactive type support. This can be used to
* to adapt to, and/or determine the semantics of a given
* {@link org.reactivestreams.Publisher Publisher}.
* <p>By default this {@link ReactiveAdapterRegistry#sharedInstance}.
* @param registry the registry to use
*/
Builder reactiveAdapterStrategy(ReactiveAdapterRegistry registry);
/**
* Configure the DataBufferFactory to use for the allocation of buffers
* when creating or responding requests.
* <p>By default this is an instance of
* {@link org.springframework.core.io.buffer.NettyDataBufferFactory
* NettyDataBufferFactory} with {@link PooledByteBufAllocator#DEFAULT}.
* @param bufferFactory the buffer factory to use
*/
Builder dataBufferFactory(DataBufferFactory bufferFactory);
/**
* Builder the {@code RSocketStrategies} instance.
*/
RSocketStrategies build();
}
}

View File

@ -0,0 +1,275 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import io.reactivex.Completable;
import io.reactivex.Observable;
import io.reactivex.Single;
import io.rsocket.AbstractRSocket;
import io.rsocket.Payload;
import org.junit.Before;
import org.junit.Test;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.lang.Nullable;
import org.springframework.messaging.rsocket.RSocketRequester.RequestSpec;
import org.springframework.messaging.rsocket.RSocketRequester.ResponseSpec;
import org.springframework.util.MimeTypeUtils;
import static java.util.concurrent.TimeUnit.*;
import static org.junit.Assert.*;
/**
* Unit tests for {@link DefaultRSocketRequester}.
*
* @author Rossen Stoyanchev
*/
public class DefaultRSocketRequesterTests {
private static final Duration MILLIS_10 = Duration.ofMillis(10);
private TestRSocket rsocket;
private RSocketRequester requester;
private final DefaultDataBufferFactory bufferFactory = new DefaultDataBufferFactory();
@Before
public void setUp() {
RSocketStrategies strategies = RSocketStrategies.builder()
.decoder(StringDecoder.allMimeTypes())
.encoder(CharSequenceEncoder.allMimeTypes())
.build();
this.rsocket = new TestRSocket();
this.requester = RSocketRequester.create(rsocket, MimeTypeUtils.TEXT_PLAIN, strategies);
}
@Test
public void singlePayload() {
// data(Object)
testSinglePayload(spec -> spec.data("bodyA"), "bodyA");
testSinglePayload(spec -> spec.data(Mono.delay(MILLIS_10).map(l -> "bodyA")), "bodyA");
testSinglePayload(spec -> spec.data(Mono.delay(MILLIS_10).then()), "");
testSinglePayload(spec -> spec.data(Single.timer(10, MILLISECONDS).map(l -> "bodyA")), "bodyA");
testSinglePayload(spec -> spec.data(Completable.complete()), "");
// data(Publisher<T>, Class<T>)
testSinglePayload(spec -> spec.data(Mono.delay(MILLIS_10).map(l -> "bodyA"), String.class), "bodyA");
testSinglePayload(spec -> spec.data(Mono.delay(MILLIS_10).map(l -> "bodyA"), Object.class), "bodyA");
testSinglePayload(spec -> spec.data(Mono.delay(MILLIS_10).then(), Void.class), "");
}
private void testSinglePayload(Function<RequestSpec, ResponseSpec> mapper, String expectedValue) {
mapper.apply(this.requester.route("toA")).send().block(Duration.ofSeconds(5));
assertEquals("fireAndForget", this.rsocket.getSavedMethodName());
assertEquals("toA", this.rsocket.getSavedPayload().getMetadataUtf8());
assertEquals(expectedValue, this.rsocket.getSavedPayload().getDataUtf8());
}
@Test
public void multiPayload() {
String[] values = new String[] {"bodyA", "bodyB", "bodyC"};
Flux<String> stringFlux = Flux.fromArray(values).delayElements(MILLIS_10);
// data(Object)
testMultiPayload(spec -> spec.data(stringFlux), values);
testMultiPayload(spec -> spec.data(Flux.empty()), "");
testMultiPayload(spec -> spec.data(Observable.fromArray(values).delay(10, MILLISECONDS)), values);
testMultiPayload(spec -> spec.data(Observable.empty()), "");
// data(Publisher<T>, Class<T>)
testMultiPayload(spec -> spec.data(stringFlux, String.class), values);
testMultiPayload(spec -> spec.data(stringFlux.cast(Object.class), Object.class), values);
}
private void testMultiPayload(Function<RequestSpec, ResponseSpec> mapper, String... expectedValues) {
this.rsocket.reset();
mapper.apply(this.requester.route("toA")).retrieveFlux(String.class).blockLast(Duration.ofSeconds(5));
assertEquals("requestChannel", this.rsocket.getSavedMethodName());
List<Payload> payloads = this.rsocket.getSavedPayloadFlux().collectList().block(Duration.ofSeconds(5));
assertNotNull(payloads);
if (Arrays.equals(new String[] {""}, expectedValues)) {
assertEquals(1, payloads.size());
assertEquals("toA", payloads.get(0).getMetadataUtf8());
assertEquals("", payloads.get(0).getDataUtf8());
}
else {
assertArrayEquals(new String[] {"toA", "", ""},
payloads.stream().map(Payload::getMetadataUtf8).toArray(String[]::new));
assertArrayEquals(expectedValues,
payloads.stream().map(Payload::getDataUtf8).toArray(String[]::new));
}
}
@Test
public void send() {
String value = "bodyA";
this.requester.route("toA").data(value).send().block(Duration.ofSeconds(5));
assertEquals("fireAndForget", this.rsocket.getSavedMethodName());
assertEquals("toA", this.rsocket.getSavedPayload().getMetadataUtf8());
assertEquals("bodyA", this.rsocket.getSavedPayload().getDataUtf8());
}
@Test
public void retrieveMono() {
String value = "bodyA";
this.rsocket.setPayloadMonoToReturn(Mono.delay(MILLIS_10).thenReturn(toPayload(value)));
Mono<String> response = this.requester.route("").data("").retrieveMono(String.class);
StepVerifier.create(response).expectNext(value).expectComplete().verify(Duration.ofSeconds(5));
assertEquals("requestResponse", this.rsocket.getSavedMethodName());
}
@Test
public void retrieveMonoVoid() {
AtomicBoolean consumed = new AtomicBoolean(false);
Mono<Payload> mono = Mono.delay(MILLIS_10).thenReturn(toPayload("bodyA")).doOnSuccess(p -> consumed.set(true));
this.rsocket.setPayloadMonoToReturn(mono);
this.requester.route("").data("").retrieveMono(Void.class).block(Duration.ofSeconds(5));
assertTrue(consumed.get());
assertEquals("requestResponse", this.rsocket.getSavedMethodName());
}
@Test
public void retrieveFlux() {
String[] values = new String[] {"bodyA", "bodyB", "bodyC"};
this.rsocket.setPayloadFluxToReturn(Flux.fromArray(values).delayElements(MILLIS_10).map(this::toPayload));
Flux<String> response = this.requester.route("").data("").retrieveFlux(String.class);
StepVerifier.create(response).expectNext(values).expectComplete().verify(Duration.ofSeconds(5));
assertEquals("requestStream", this.rsocket.getSavedMethodName());
}
@Test
public void retrieveFluxVoid() {
AtomicBoolean consumed = new AtomicBoolean(false);
Flux<Payload> flux = Flux.just("bodyA", "bodyB")
.delayElements(MILLIS_10).map(this::toPayload).doOnComplete(() -> consumed.set(true));
this.rsocket.setPayloadFluxToReturn(flux);
this.requester.route("").data("").retrieveFlux(Void.class).blockLast(Duration.ofSeconds(5));
assertTrue(consumed.get());
assertEquals("requestStream", this.rsocket.getSavedMethodName());
}
@Test
public void rejectFluxToMono() {
try {
this.requester.route("").data(Flux.just("a", "b")).retrieveMono(String.class);
fail();
}
catch (IllegalArgumentException ex) {
assertEquals("No RSocket interaction model for Flux request to Mono response.", ex.getMessage());
}
}
private Payload toPayload(String value) {
return PayloadUtils.asPayload(bufferFactory.wrap(value.getBytes(StandardCharsets.UTF_8)));
}
private static class TestRSocket extends AbstractRSocket {
private Mono<Payload> payloadMonoToReturn = Mono.empty();
private Flux<Payload> payloadFluxToReturn = Flux.empty();
@Nullable private volatile String savedMethodName;
@Nullable private volatile Payload savedPayload;
@Nullable private volatile Flux<Payload> savedPayloadFlux;
void setPayloadMonoToReturn(Mono<Payload> payloadMonoToReturn) {
this.payloadMonoToReturn = payloadMonoToReturn;
}
void setPayloadFluxToReturn(Flux<Payload> payloadFluxToReturn) {
this.payloadFluxToReturn = payloadFluxToReturn;
}
@Nullable
String getSavedMethodName() {
return this.savedMethodName;
}
@Nullable
Payload getSavedPayload() {
return this.savedPayload;
}
@Nullable
Flux<Payload> getSavedPayloadFlux() {
return this.savedPayloadFlux;
}
public void reset() {
this.savedMethodName = null;
this.savedPayload = null;
this.savedPayloadFlux = null;
}
@Override
public Mono<Void> fireAndForget(Payload payload) {
this.savedMethodName = "fireAndForget";
this.savedPayload = payload;
return Mono.empty();
}
@Override
public Mono<Payload> requestResponse(Payload payload) {
this.savedMethodName = "requestResponse";
this.savedPayload = payload;
return this.payloadMonoToReturn;
}
@Override
public Flux<Payload> requestStream(Payload payload) {
this.savedMethodName = "requestStream";
this.savedPayload = payload;
return this.payloadFluxToReturn;
}
@Override
public Flux<Payload> requestChannel(Publisher<Payload> publisher) {
this.savedMethodName = "requestChannel";
this.savedPayloadFlux = Flux.from(publisher);
return this.payloadFluxToReturn;
}
}
}

View File

@ -16,15 +16,12 @@
package org.springframework.messaging.rsocket;
import java.time.Duration;
import java.util.Collections;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.RSocketFactory;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.transport.netty.server.CloseableChannel;
import io.rsocket.transport.netty.server.TcpServerTransport;
import io.rsocket.util.DefaultPayload;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@ -43,6 +40,7 @@ import org.springframework.messaging.ReactiveSubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.support.DefaultReactiveMessageChannel;
import org.springframework.stereotype.Controller;
import org.springframework.util.MimeTypeUtils;
import static org.junit.Assert.*;
@ -55,11 +53,13 @@ public class RSocketClientToServerIntegrationTests {
private static AnnotationConfigApplicationContext context;
private static CloseableChannel serverChannel;
private static CloseableChannel server;
private static FireAndForgetCountingInterceptor interceptor = new FireAndForgetCountingInterceptor();
private static RSocket clientRsocket;
private static RSocket client;
private static RSocketRequester requester;
@BeforeClass
@ -68,27 +68,30 @@ public class RSocketClientToServerIntegrationTests {
context = new AnnotationConfigApplicationContext(ServerConfig.class);
MessagingAcceptor acceptor = new MessagingAcceptor(
context.getBean("rsocketChannel", ReactiveMessageChannel.class));
ReactiveMessageChannel messageChannel = context.getBean(ReactiveMessageChannel.class);
RSocketStrategies rsocketStrategies = context.getBean(RSocketStrategies.class);
serverChannel = RSocketFactory.receive()
server = RSocketFactory.receive()
.addServerPlugin(interceptor)
.acceptor(acceptor)
.acceptor(new MessagingAcceptor(messageChannel))
.transport(TcpServerTransport.create("localhost", 7000))
.start()
.block();
clientRsocket = RSocketFactory.connect()
.dataMimeType("text/plain")
client = RSocketFactory.connect()
.dataMimeType(MimeTypeUtils.TEXT_PLAIN_VALUE)
.transport(TcpClientTransport.create("localhost", 7000))
.start()
.block();
requester = RSocketRequester.create(
client, MimeTypeUtils.TEXT_PLAIN, rsocketStrategies);
}
@AfterClass
public static void tearDownOnce() {
clientRsocket.dispose();
serverChannel.dispose();
client.dispose();
server.dispose();
}
@ -96,7 +99,7 @@ public class RSocketClientToServerIntegrationTests {
public void fireAndForget() {
Flux.range(1, 3)
.concatMap(i -> clientRsocket.fireAndForget(payload("receive", "Hello " + i)))
.concatMap(i -> requester.route("receive").data("Hello " + i).send())
.blockLast();
StepVerifier.create(context.getBean(ServerController.class).fireForgetPayloads)
@ -115,7 +118,7 @@ public class RSocketClientToServerIntegrationTests {
public void echo() {
Flux<String> result = Flux.range(1, 3).concatMap(i ->
clientRsocket.requestResponse(payload("echo", "Hello " + i)).map(Payload::getDataUtf8));
requester.route("echo").data("Hello " + i).retrieveMono(String.class));
StepVerifier.create(result)
.expectNext("Hello 1")
@ -128,7 +131,7 @@ public class RSocketClientToServerIntegrationTests {
public void echoAsync() {
Flux<String> result = Flux.range(1, 3).concatMap(i ->
clientRsocket.requestResponse(payload("echo-async", "Hello " + i)).map(Payload::getDataUtf8));
requester.route("echo-async").data("Hello " + i).retrieveMono(String.class));
StepVerifier.create(result)
.expectNext("Hello 1 async")
@ -140,8 +143,7 @@ public class RSocketClientToServerIntegrationTests {
@Test
public void echoStream() {
Flux<String> result = clientRsocket.requestStream(payload("echo-stream", "Hello"))
.map(io.rsocket.Payload::getDataUtf8);
Flux<String> result = requester.route("echo-stream").data("Hello").retrieveFlux(String.class);
StepVerifier.create(result)
.expectNext("Hello 0")
@ -155,11 +157,9 @@ public class RSocketClientToServerIntegrationTests {
@Test
public void echoChannel() {
Flux<Payload> payloads = Flux.concat(
Flux.just(payload("echo-channel", "Hello 1")),
Flux.range(2, 9).map(i -> DefaultPayload.create("Hello " + i)));
Flux<String> result = clientRsocket.requestChannel(payloads).map(Payload::getDataUtf8);
Flux<String> result = requester.route("echo-channel")
.data(Flux.range(1, 10).map(i -> "Hello " + i), String.class)
.retrieveFlux(String.class);
StepVerifier.create(result)
.expectNext("Hello 1 async")
@ -170,12 +170,6 @@ public class RSocketClientToServerIntegrationTests {
}
private static Payload payload(String destination, String data) {
return DefaultPayload.create(data, destination);
}
@Controller
static class ServerController {
@ -226,10 +220,17 @@ public class RSocketClientToServerIntegrationTests {
@Bean
public RSocketMessageHandler rsocketMessageHandler() {
RSocketMessageHandler handler = new RSocketMessageHandler(rsocketChannel());
handler.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes()));
handler.setEncoders(Collections.singletonList(CharSequenceEncoder.allMimeTypes()));
handler.setRSocketStrategies(rsocketStrategies());
return handler;
}
@Bean
public RSocketStrategies rsocketStrategies() {
return RSocketStrategies.builder()
.decoder(StringDecoder.allMimeTypes())
.encoder(CharSequenceEncoder.allMimeTypes())
.build();
}
}
}

View File

@ -19,11 +19,11 @@ import java.time.Duration;
import java.util.Collections;
import java.util.List;
import io.rsocket.Closeable;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.RSocketFactory;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.transport.netty.server.CloseableChannel;
import io.rsocket.transport.netty.server.TcpServerTransport;
import io.rsocket.util.DefaultPayload;
import org.junit.AfterClass;
@ -56,7 +56,7 @@ public class RSocketServerToClientIntegrationTests {
private static AnnotationConfigApplicationContext context;
private static CloseableChannel serverChannel;
private static Closeable server;
private static MessagingAcceptor clientAcceptor;
@ -67,14 +67,14 @@ public class RSocketServerToClientIntegrationTests {
context = new AnnotationConfigApplicationContext(ServerConfig.class);
ReactiveMessageChannel messageChannel = context.getBean("serverChannel", ReactiveMessageChannel.class);
RSocketStrategies rsocketStrategies = context.getBean(RSocketStrategies.class);
clientAcceptor = new MessagingAcceptor(
context.getBean("clientChannel", ReactiveMessageChannel.class));
MessagingAcceptor serverAcceptor = new MessagingAcceptor(
context.getBean("serverChannel", ReactiveMessageChannel.class));
serverChannel = RSocketFactory.receive()
.acceptor(serverAcceptor)
server = RSocketFactory.receive()
.acceptor(new MessagingAcceptor(messageChannel, rsocketStrategies))
.transport(TcpServerTransport.create("localhost", 7000))
.start()
.block();
@ -82,7 +82,7 @@ public class RSocketServerToClientIntegrationTests {
@AfterClass
public static void tearDownOnce() {
serverChannel.dispose();
server.dispose();
}
@ -141,10 +141,10 @@ public class RSocketServerToClientIntegrationTests {
@MessageMapping("connect.echo")
void echo(RSocket rsocket) {
void echo(RSocketRequester requester) {
runTest(() -> {
Flux<String> result = Flux.range(1, 3).concatMap(i ->
rsocket.requestResponse(payload("echo", "Hello " + i)).map(Payload::getDataUtf8));
requester.route("echo").data("Hello " + i).retrieveMono(String.class));
StepVerifier.create(result)
.expectNext("Hello 1")
@ -155,10 +155,10 @@ public class RSocketServerToClientIntegrationTests {
}
@MessageMapping("connect.echo-async")
void echoAsync(RSocket rsocket) {
void echoAsync(RSocketRequester requester) {
runTest(() -> {
Flux<String> result = Flux.range(1, 3).concatMap(i ->
rsocket.requestResponse(payload("echo-async", "Hello " + i)).map(Payload::getDataUtf8));
requester.route("echo-async").data("Hello " + i).retrieveMono(String.class));
StepVerifier.create(result)
.expectNext("Hello 1 async")
@ -169,10 +169,9 @@ public class RSocketServerToClientIntegrationTests {
}
@MessageMapping("connect.echo-stream")
void echoStream(RSocket rsocket) {
void echoStream(RSocketRequester requester) {
runTest(() -> {
Flux<String> result = rsocket.requestStream(payload("echo-stream", "Hello"))
.map(io.rsocket.Payload::getDataUtf8);
Flux<String> result = requester.route("echo-stream").data("Hello").retrieveFlux(String.class);
StepVerifier.create(result)
.expectNext("Hello 0")
@ -185,13 +184,11 @@ public class RSocketServerToClientIntegrationTests {
}
@MessageMapping("connect.echo-channel")
void echoChannel(RSocket rsocket) {
void echoChannel(RSocketRequester requester) {
runTest(() -> {
Flux<Payload> payloads = Flux.concat(
Flux.just(payload("echo-channel", "Hello 1")),
Flux.range(2, 9).map(i -> DefaultPayload.create("Hello " + i)));
Flux<String> result = rsocket.requestChannel(payloads).map(Payload::getDataUtf8);
Flux<String> result = requester.route("echo-channel")
.data(Flux.range(1, 10).map(i -> "Hello " + i), String.class)
.retrieveFlux(String.class);
StepVerifier.create(result)
.expectNext("Hello 1 async")
@ -285,20 +282,23 @@ public class RSocketServerToClientIntegrationTests {
public RSocketMessageHandler clientMessageHandler() {
List<Object> handlers = Collections.singletonList(clientController());
RSocketMessageHandler handler = new RSocketMessageHandler(clientChannel(), handlers);
addDefaultCodecs(handler);
handler.setRSocketStrategies(rsocketStrategies());
return handler;
}
@Bean
public RSocketMessageHandler serverMessageHandler() {
RSocketMessageHandler handler = new RSocketMessageHandler(serverChannel());
addDefaultCodecs(handler);
handler.setRSocketStrategies(rsocketStrategies());
return handler;
}
private void addDefaultCodecs(RSocketMessageHandler handler) {
handler.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes()));
handler.setEncoders(Collections.singletonList(CharSequenceEncoder.allMimeTypes()));
@Bean
public RSocketStrategies rsocketStrategies() {
return RSocketStrategies.builder()
.decoder(StringDecoder.allMimeTypes())
.encoder(CharSequenceEncoder.allMimeTypes())
.build();
}
}