Updates for buffer management in RSocket
- Integration tests run with zero copy configuration. - RSocketBufferLeakTests has been added. - Updates in MessagingRSocket to ensure proper release See gh-21987
This commit is contained in:
parent
23b39ad27b
commit
9e7f557b4a
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
@ -105,13 +105,15 @@ public abstract class AbstractDataBufferAllocatingTestCase {
|
|||
*/
|
||||
protected void waitForDataBufferRelease(Duration duration) throws InterruptedException {
|
||||
Instant start = Instant.now();
|
||||
while (Instant.now().isBefore(start.plus(duration))) {
|
||||
while (true) {
|
||||
try {
|
||||
verifyAllocations();
|
||||
break;
|
||||
}
|
||||
catch (AssertionError ex) {
|
||||
// ignore;
|
||||
if (Instant.now().isAfter(start.plus(duration))) {
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
Thread.sleep(50);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -396,10 +396,10 @@ public abstract class AbstractMethodMessageHandler<T>
|
|||
if (matches.size() > 1) {
|
||||
Match<T> secondBestMatch = matches.get(1);
|
||||
if (comparator.compare(bestMatch, secondBestMatch) == 0) {
|
||||
Method m1 = bestMatch.handlerMethod.getMethod();
|
||||
Method m2 = secondBestMatch.handlerMethod.getMethod();
|
||||
HandlerMethod m1 = bestMatch.handlerMethod;
|
||||
HandlerMethod m2 = secondBestMatch.handlerMethod;
|
||||
throw new IllegalStateException("Ambiguous handler methods mapped for destination '" +
|
||||
destination + "': {" + m1 + ", " + m2 + "}");
|
||||
destination + "': {" + m1.getShortLogMessage() + ", " + m2.getShortLogMessage() + "}");
|
||||
}
|
||||
}
|
||||
return bestMatch;
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ final class DefaultRSocketRequester implements RSocketRequester {
|
|||
|
||||
Decoder<?> decoder = strategies.decoder(elementType, dataMimeType);
|
||||
return (Mono<T>) decoder.decodeToMono(
|
||||
payloadMono.map(this::wrapPayloadData), elementType, dataMimeType, EMPTY_HINTS);
|
||||
payloadMono.map(this::retainDataAndReleasePayload), elementType, dataMimeType, EMPTY_HINTS);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
|
@ -260,12 +260,12 @@ final class DefaultRSocketRequester implements RSocketRequester {
|
|||
|
||||
Decoder<?> decoder = strategies.decoder(elementType, dataMimeType);
|
||||
|
||||
return payloadFlux.map(this::wrapPayloadData).concatMap(dataBuffer ->
|
||||
return payloadFlux.map(this::retainDataAndReleasePayload).concatMap(dataBuffer ->
|
||||
(Mono<T>) decoder.decodeToMono(Mono.just(dataBuffer), elementType, dataMimeType, EMPTY_HINTS));
|
||||
}
|
||||
|
||||
private DataBuffer wrapPayloadData(Payload payload) {
|
||||
return PayloadUtils.wrapPayloadData(payload, strategies.dataBufferFactory());
|
||||
private DataBuffer retainDataAndReleasePayload(Payload payload) {
|
||||
return PayloadUtils.retainDataAndReleasePayload(payload, strategies.dataBufferFactory());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,14 +21,13 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import io.netty.buffer.PooledByteBufAllocator;
|
||||
|
||||
import org.springframework.core.ReactiveAdapterRegistry;
|
||||
import org.springframework.core.codec.Decoder;
|
||||
import org.springframework.core.codec.Encoder;
|
||||
import org.springframework.core.io.buffer.DataBufferFactory;
|
||||
import org.springframework.core.io.buffer.NettyDataBufferFactory;
|
||||
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* Default, package-private {@link RSocketStrategies} implementation.
|
||||
|
|
@ -88,11 +87,10 @@ final class DefaultRSocketStrategies implements RSocketStrategies {
|
|||
|
||||
private final List<Decoder<?>> decoders = new ArrayList<>();
|
||||
|
||||
@Nullable
|
||||
private ReactiveAdapterRegistry adapterRegistry;
|
||||
private ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry.getSharedInstance();
|
||||
|
||||
@Nullable
|
||||
private DataBufferFactory bufferFactory;
|
||||
private DataBufferFactory dataBufferFactory;
|
||||
|
||||
|
||||
@Override
|
||||
|
|
@ -121,23 +119,21 @@ final class DefaultRSocketStrategies implements RSocketStrategies {
|
|||
|
||||
@Override
|
||||
public Builder reactiveAdapterStrategy(ReactiveAdapterRegistry registry) {
|
||||
Assert.notNull(registry, "ReactiveAdapterRegistry is required");
|
||||
this.adapterRegistry = registry;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder dataBufferFactory(DataBufferFactory bufferFactory) {
|
||||
this.bufferFactory = bufferFactory;
|
||||
this.dataBufferFactory = bufferFactory;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RSocketStrategies build() {
|
||||
return new DefaultRSocketStrategies(this.encoders, this.decoders,
|
||||
this.adapterRegistry != null ?
|
||||
this.adapterRegistry : ReactiveAdapterRegistry.getSharedInstance(),
|
||||
this.bufferFactory != null ? this.bufferFactory :
|
||||
new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT));
|
||||
return new DefaultRSocketStrategies(this.encoders, this.decoders, this.adapterRegistry,
|
||||
this.dataBufferFactory != null ? this.dataBufferFactory : new DefaultDataBufferFactory());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
package org.springframework.messaging.rsocket;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Function;
|
||||
|
||||
import io.rsocket.AbstractRSocket;
|
||||
|
|
@ -29,7 +30,7 @@ import reactor.core.publisher.MonoProcessor;
|
|||
import org.springframework.core.io.buffer.DataBuffer;
|
||||
import org.springframework.core.io.buffer.DataBufferFactory;
|
||||
import org.springframework.core.io.buffer.DataBufferUtils;
|
||||
import org.springframework.core.io.buffer.PooledDataBuffer;
|
||||
import org.springframework.core.io.buffer.NettyDataBuffer;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageHeaders;
|
||||
|
|
@ -84,6 +85,9 @@ class MessagingRSocket extends AbstractRSocket {
|
|||
if (StringUtils.hasText(payload.dataMimeType())) {
|
||||
this.dataMimeType = MimeTypeUtils.parseMimeType(payload.dataMimeType());
|
||||
}
|
||||
// frameDecoder does not apply to connectionSetupPayload
|
||||
// so retain here since handle expects it..
|
||||
payload.retain();
|
||||
return handle(payload);
|
||||
}
|
||||
|
||||
|
|
@ -120,54 +124,72 @@ class MessagingRSocket extends AbstractRSocket {
|
|||
|
||||
|
||||
private Mono<Void> handle(Payload payload) {
|
||||
Message<?> message = MessageBuilder.createMessage(
|
||||
Mono.fromCallable(() -> wrapPayloadData(payload)), createHeaders(payload, null));
|
||||
String destination = getDestination(payload);
|
||||
MessageHeaders headers = createHeaders(destination, null);
|
||||
DataBuffer dataBuffer = retainDataAndReleasePayload(payload);
|
||||
int refCount = refCount(dataBuffer);
|
||||
Message<?> message = MessageBuilder.createMessage(dataBuffer, headers);
|
||||
return Mono.defer(() -> this.handler.apply(message))
|
||||
.doFinally(s -> {
|
||||
if (refCount(dataBuffer) == refCount) {
|
||||
DataBufferUtils.release(dataBuffer);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return this.handler.apply(message);
|
||||
private int refCount(DataBuffer dataBuffer) {
|
||||
return dataBuffer instanceof NettyDataBuffer ?
|
||||
((NettyDataBuffer) dataBuffer).getNativeBuffer().refCnt() : 1;
|
||||
}
|
||||
|
||||
private Flux<Payload> handleAndReply(Payload firstPayload, Flux<Payload> payloads) {
|
||||
MonoProcessor<Flux<Payload>> replyMono = MonoProcessor.create();
|
||||
Message<?> message = MessageBuilder.createMessage(
|
||||
payloads.map(this::wrapPayloadData).doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release),
|
||||
createHeaders(firstPayload, replyMono));
|
||||
String destination = getDestination(firstPayload);
|
||||
MessageHeaders headers = createHeaders(destination, replyMono);
|
||||
|
||||
return this.handler.apply(message)
|
||||
AtomicBoolean read = new AtomicBoolean();
|
||||
Flux<DataBuffer> buffers = payloads.map(this::retainDataAndReleasePayload).doOnSubscribe(s -> read.set(true));
|
||||
Message<Flux<DataBuffer>> message = MessageBuilder.createMessage(buffers, headers);
|
||||
|
||||
return Mono.defer(() -> this.handler.apply(message))
|
||||
.doFinally(s -> {
|
||||
// Subscription should have happened by now due to ChannelSendOperator
|
||||
if (!read.get()) {
|
||||
buffers.subscribe(DataBufferUtils::release);
|
||||
}
|
||||
})
|
||||
.thenMany(Flux.defer(() -> replyMono.isTerminated() ?
|
||||
replyMono.flatMapMany(Function.identity()) :
|
||||
Mono.error(new IllegalStateException("Something went wrong: reply Mono not set"))));
|
||||
}
|
||||
|
||||
private MessageHeaders createHeaders(Payload payload, @Nullable MonoProcessor<?> replyMono) {
|
||||
private String getDestination(Payload payload) {
|
||||
|
||||
// TODO:
|
||||
// For now treat the metadata as a simple string with routing information.
|
||||
// We'll have to get more sophisticated once the routing extension is completed.
|
||||
// https://github.com/rsocket/rsocket-java/issues/568
|
||||
|
||||
return payload.getMetadataUtf8();
|
||||
}
|
||||
|
||||
private DataBuffer retainDataAndReleasePayload(Payload payload) {
|
||||
return PayloadUtils.retainDataAndReleasePayload(payload, this.strategies.dataBufferFactory());
|
||||
}
|
||||
|
||||
private MessageHeaders createHeaders(String destination, @Nullable MonoProcessor<?> replyMono) {
|
||||
MessageHeaderAccessor headers = new MessageHeaderAccessor();
|
||||
|
||||
String destination = payload.getMetadataUtf8();
|
||||
headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination);
|
||||
|
||||
if (this.dataMimeType != null) {
|
||||
headers.setContentType(this.dataMimeType);
|
||||
}
|
||||
|
||||
headers.setHeader(RSocketRequesterMethodArgumentResolver.RSOCKET_REQUESTER_HEADER, this.requester);
|
||||
|
||||
if (replyMono != null) {
|
||||
headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono);
|
||||
}
|
||||
|
||||
DataBufferFactory bufferFactory = this.strategies.dataBufferFactory();
|
||||
headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, bufferFactory);
|
||||
|
||||
return headers.getMessageHeaders();
|
||||
}
|
||||
|
||||
private DataBuffer wrapPayloadData(Payload payload) {
|
||||
return PayloadUtils.wrapPayloadData(payload, this.strategies.dataBufferFactory());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
package org.springframework.messaging.rsocket;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.rsocket.Frame;
|
||||
import io.rsocket.Payload;
|
||||
import io.rsocket.util.ByteBufPayload;
|
||||
import io.rsocket.util.DefaultPayload;
|
||||
|
|
@ -24,6 +26,7 @@ import org.springframework.core.io.buffer.DataBufferFactory;
|
|||
import org.springframework.core.io.buffer.DefaultDataBuffer;
|
||||
import org.springframework.core.io.buffer.NettyDataBuffer;
|
||||
import org.springframework.core.io.buffer.NettyDataBufferFactory;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* Static utility methods to create {@link Payload} from {@link DataBuffer}s
|
||||
|
|
@ -35,19 +38,31 @@ import org.springframework.core.io.buffer.NettyDataBufferFactory;
|
|||
abstract class PayloadUtils {
|
||||
|
||||
/**
|
||||
* Return the Payload data wrapped as DataBuffer. If the bufferFactory is
|
||||
* {@link NettyDataBufferFactory} the payload retained and sliced.
|
||||
* @param payload the input payload
|
||||
* @param bufferFactory the BufferFactory to use to wrap
|
||||
* @return the DataBuffer wrapper
|
||||
* Use this method to slice, retain and wrap the data portion of the
|
||||
* {@code Payload}, and also to release the {@code Payload}. This assumes
|
||||
* the Payload metadata has been read by now and ensures downstream code
|
||||
* need only be aware of {@code DataBuffer}s.
|
||||
* @param payload the payload to process
|
||||
* @param bufferFactory the DataBufferFactory to wrap with
|
||||
* @return the created {@code DataBuffer} instance
|
||||
*/
|
||||
public static DataBuffer wrapPayloadData(Payload payload, DataBufferFactory bufferFactory) {
|
||||
if (bufferFactory instanceof NettyDataBufferFactory) {
|
||||
return ((NettyDataBufferFactory) bufferFactory).wrap(payload.retain().sliceData());
|
||||
}
|
||||
else {
|
||||
public static DataBuffer retainDataAndReleasePayload(Payload payload, DataBufferFactory bufferFactory) {
|
||||
try {
|
||||
if (bufferFactory instanceof NettyDataBufferFactory) {
|
||||
ByteBuf byteBuf = payload.sliceData().retain();
|
||||
return ((NettyDataBufferFactory) bufferFactory).wrap(byteBuf);
|
||||
}
|
||||
|
||||
Assert.isTrue(!(payload instanceof ByteBufPayload) && !(payload instanceof Frame),
|
||||
"NettyDataBufferFactory expected, actual: " + bufferFactory.getClass().getSimpleName());
|
||||
|
||||
return bufferFactory.wrap(payload.getData());
|
||||
}
|
||||
finally {
|
||||
if (payload.refCnt() > 0) {
|
||||
payload.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -142,12 +142,23 @@ public interface RSocketStrategies {
|
|||
Builder reactiveAdapterStrategy(ReactiveAdapterRegistry registry);
|
||||
|
||||
/**
|
||||
* Configure the DataBufferFactory to use for the allocation of buffers
|
||||
* when creating or responding requests.
|
||||
* <p>By default this is an instance of
|
||||
* Configure the DataBufferFactory to use for allocating buffers, for
|
||||
* example when preparing requests or when responding. The choice here
|
||||
* must be aligned with the frame decoder configured in
|
||||
* {@link io.rsocket.RSocketFactory}.
|
||||
* <p>By default this property is an instance of
|
||||
* {@link org.springframework.core.io.buffer.DefaultDataBufferFactory
|
||||
* DefaultDataBufferFactory} matching to the default frame decoder in
|
||||
* {@link io.rsocket.RSocketFactory} which copies the payload. This
|
||||
* comes at cost to performance but does not require reference counting
|
||||
* and eliminates possibility for memory leaks.
|
||||
* <p>To switch to a zero-copy strategy,
|
||||
* <a href="https://github.com/rsocket/rsocket-java#zero-copy">configure RSocket</a>
|
||||
* accordingly, and then configure this property with an instance of
|
||||
* {@link org.springframework.core.io.buffer.NettyDataBufferFactory
|
||||
* NettyDataBufferFactory} with {@link PooledByteBufAllocator#DEFAULT}.
|
||||
* @param bufferFactory the buffer factory to use
|
||||
* NettyDataBufferFactory} with a pooled allocator such as
|
||||
* {@link PooledByteBufAllocator#DEFAULT}.
|
||||
* @param bufferFactory the DataBufferFactory to use
|
||||
*/
|
||||
Builder dataBufferFactory(DataBufferFactory bufferFactory);
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,466 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.messaging.rsocket;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.buffer.PooledByteBufAllocator;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.util.ReferenceCounted;
|
||||
import io.rsocket.AbstractRSocket;
|
||||
import io.rsocket.Frame;
|
||||
import io.rsocket.RSocket;
|
||||
import io.rsocket.RSocketFactory;
|
||||
import io.rsocket.plugins.RSocketInterceptor;
|
||||
import io.rsocket.transport.netty.client.TcpClientTransport;
|
||||
import io.rsocket.transport.netty.server.CloseableChannel;
|
||||
import io.rsocket.transport.netty.server.TcpServerTransport;
|
||||
import org.junit.After;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
import org.reactivestreams.Publisher;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.publisher.ReplayProcessor;
|
||||
import reactor.test.StepVerifier;
|
||||
|
||||
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.core.codec.CharSequenceEncoder;
|
||||
import org.springframework.core.codec.StringDecoder;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.core.io.buffer.DataBuffer;
|
||||
import org.springframework.core.io.buffer.NettyDataBuffer;
|
||||
import org.springframework.core.io.buffer.NettyDataBufferFactory;
|
||||
import org.springframework.core.io.buffer.PooledDataBuffer;
|
||||
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
|
||||
import org.springframework.messaging.handler.annotation.MessageMapping;
|
||||
import org.springframework.messaging.handler.annotation.Payload;
|
||||
import org.springframework.stereotype.Controller;
|
||||
import org.springframework.util.MimeTypeUtils;
|
||||
import org.springframework.util.ObjectUtils;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Tests for scenarios that could lead to Payload and/or DataBuffer leaks.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
*/
|
||||
public class RSocketBufferLeakTests {
|
||||
|
||||
private static AnnotationConfigApplicationContext context;
|
||||
|
||||
private static final PayloadInterceptor payloadInterceptor = new PayloadInterceptor();
|
||||
|
||||
private static CloseableChannel server;
|
||||
|
||||
private static RSocket client;
|
||||
|
||||
private static RSocketRequester requester;
|
||||
|
||||
|
||||
@BeforeClass
|
||||
@SuppressWarnings("ConstantConditions")
|
||||
public static void setupOnce() {
|
||||
|
||||
context = new AnnotationConfigApplicationContext(ServerConfig.class);
|
||||
|
||||
server = RSocketFactory.receive()
|
||||
.frameDecoder(Frame::retain) // zero copy
|
||||
.addServerPlugin(payloadInterceptor) // intercept responding
|
||||
.acceptor(context.getBean(MessageHandlerAcceptor.class))
|
||||
.transport(TcpServerTransport.create("localhost", 7000))
|
||||
.start()
|
||||
.block();
|
||||
|
||||
client = RSocketFactory.connect()
|
||||
.frameDecoder(Frame::retain) // zero copy
|
||||
.addClientPlugin(payloadInterceptor) // intercept outgoing requests
|
||||
.dataMimeType(MimeTypeUtils.TEXT_PLAIN_VALUE)
|
||||
.transport(TcpClientTransport.create("localhost", 7000))
|
||||
.start()
|
||||
.block();
|
||||
|
||||
requester = RSocketRequester.create(
|
||||
client, MimeTypeUtils.TEXT_PLAIN, context.getBean(RSocketStrategies.class));
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void tearDownOnce() {
|
||||
client.dispose();
|
||||
server.dispose();
|
||||
}
|
||||
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
getLeakAwareNettyDataBufferFactory().reset();
|
||||
payloadInterceptor.reset();
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() throws InterruptedException {
|
||||
getLeakAwareNettyDataBufferFactory().checkForLeaks(Duration.ofSeconds(5));
|
||||
payloadInterceptor.checkForLeaks();
|
||||
}
|
||||
|
||||
private LeakAwareNettyDataBufferFactory getLeakAwareNettyDataBufferFactory() {
|
||||
return (LeakAwareNettyDataBufferFactory) context.getBean(RSocketStrategies.class).dataBufferFactory();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void assemblyTimeErrorForHandleAndReply() {
|
||||
Mono<String> result = requester.route("A.B").data("foo").retrieveMono(String.class);
|
||||
StepVerifier.create(result).expectErrorMatches(ex -> {
|
||||
String prefix = "Ambiguous handler methods mapped for destination 'A.B':";
|
||||
return ex.getMessage().startsWith(prefix);
|
||||
}).verify();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void subscriptionTimeErrorForHandleAndReply() {
|
||||
Mono<String> result = requester.route("not-decodable").data("foo").retrieveMono(String.class);
|
||||
StepVerifier.create(result).expectErrorMatches(ex -> {
|
||||
String prefix = "Cannot decode to [org.springframework.core.io.Resource]";
|
||||
return ex.getMessage().contains(prefix);
|
||||
}).verify();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void errorSignalWithExceptionHandler() {
|
||||
Mono<String> result = requester.route("error-signal").data("foo").retrieveMono(String.class);
|
||||
StepVerifier.create(result).expectNext("Handled 'bad input'").verifyComplete();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void ignoreInput() {
|
||||
Flux<String> result = requester.route("ignore-input").data("a").retrieveFlux(String.class);
|
||||
StepVerifier.create(result).expectNext("bar").verifyComplete();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void retrieveMonoFromFluxResponderMethod() {
|
||||
Mono<String> result = requester.route("request-stream").data("foo").retrieveMono(String.class);
|
||||
StepVerifier.create(result).expectNext("foo-1").verifyComplete();
|
||||
}
|
||||
|
||||
|
||||
@Controller
|
||||
static class ServerController {
|
||||
|
||||
@MessageMapping("A.*")
|
||||
void ambiguousMatchA(String payload) {
|
||||
throw new IllegalStateException("Unexpected call");
|
||||
}
|
||||
|
||||
@MessageMapping("*.B")
|
||||
void ambiguousMatchB(String payload) {
|
||||
throw new IllegalStateException("Unexpected call");
|
||||
}
|
||||
|
||||
@MessageMapping("not-decodable")
|
||||
void notDecodable(@Payload Resource resource) {
|
||||
throw new IllegalStateException("Unexpected call");
|
||||
}
|
||||
|
||||
@MessageMapping("error-signal")
|
||||
public Flux<String> errorSignal(String payload) {
|
||||
return Flux.error(new IllegalArgumentException("bad input"))
|
||||
.delayElements(Duration.ofMillis(10))
|
||||
.cast(String.class);
|
||||
}
|
||||
|
||||
@MessageExceptionHandler
|
||||
public String handleIllegalArgument(IllegalArgumentException ex) {
|
||||
return "Handled '" + ex.getMessage() + "'";
|
||||
}
|
||||
|
||||
@MessageMapping("ignore-input")
|
||||
Mono<String> ignoreInput() {
|
||||
return Mono.delay(Duration.ofMillis(10)).map(l -> "bar");
|
||||
}
|
||||
|
||||
@MessageMapping("request-stream")
|
||||
Flux<String> stream(String payload) {
|
||||
return Flux.range(1,100).delayElements(Duration.ofMillis(10)).map(idx -> payload + "-" + idx);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Configuration
|
||||
static class ServerConfig {
|
||||
|
||||
@Bean
|
||||
public ServerController controller() {
|
||||
return new ServerController();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public MessageHandlerAcceptor messageHandlerAcceptor() {
|
||||
MessageHandlerAcceptor acceptor = new MessageHandlerAcceptor();
|
||||
acceptor.setRSocketStrategies(rsocketStrategies());
|
||||
return acceptor;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public RSocketStrategies rsocketStrategies() {
|
||||
return RSocketStrategies.builder()
|
||||
.decoder(StringDecoder.allMimeTypes())
|
||||
.encoder(CharSequenceEncoder.allMimeTypes())
|
||||
.dataBufferFactory(new LeakAwareNettyDataBufferFactory(PooledByteBufAllocator.DEFAULT))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Similar {@link org.springframework.core.io.buffer.LeakAwareDataBufferFactory}
|
||||
* but extends {@link NettyDataBufferFactory} rather than rely on
|
||||
* decoration, since {@link PayloadUtils} does instanceof checks.
|
||||
*/
|
||||
private static class LeakAwareNettyDataBufferFactory extends NettyDataBufferFactory {
|
||||
|
||||
private final List<DataBufferLeakInfo> created = new ArrayList<>();
|
||||
|
||||
|
||||
LeakAwareNettyDataBufferFactory(ByteBufAllocator byteBufAllocator) {
|
||||
super(byteBufAllocator);
|
||||
}
|
||||
|
||||
|
||||
void checkForLeaks(Duration duration) throws InterruptedException {
|
||||
Instant start = Instant.now();
|
||||
while (true) {
|
||||
try {
|
||||
this.created.forEach(info -> {
|
||||
if (((PooledDataBuffer) info.getDataBuffer()).isAllocated()) {
|
||||
throw info.getError();
|
||||
}
|
||||
});
|
||||
break;
|
||||
}
|
||||
catch (AssertionError ex) {
|
||||
if (Instant.now().isAfter(start.plus(duration))) {
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
Thread.sleep(50);
|
||||
}
|
||||
}
|
||||
|
||||
void reset() {
|
||||
this.created.clear();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public NettyDataBuffer allocateBuffer() {
|
||||
return (NettyDataBuffer) record(super.allocateBuffer());
|
||||
}
|
||||
|
||||
@Override
|
||||
public NettyDataBuffer allocateBuffer(int initialCapacity) {
|
||||
return (NettyDataBuffer) record(super.allocateBuffer(initialCapacity));
|
||||
}
|
||||
|
||||
@Override
|
||||
public NettyDataBuffer wrap(ByteBuf byteBuf) {
|
||||
NettyDataBuffer dataBuffer = super.wrap(byteBuf);
|
||||
if (byteBuf != Unpooled.EMPTY_BUFFER) {
|
||||
record(dataBuffer);
|
||||
}
|
||||
return dataBuffer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataBuffer join(List<? extends DataBuffer> dataBuffers) {
|
||||
return record(super.join(dataBuffers));
|
||||
}
|
||||
|
||||
private DataBuffer record(DataBuffer buffer) {
|
||||
this.created.add(new DataBufferLeakInfo(buffer, new AssertionError(String.format(
|
||||
"DataBuffer leak: {%s} {%s} not released.%nStacktrace at buffer creation: ", buffer,
|
||||
ObjectUtils.getIdentityHexString(((NettyDataBuffer) buffer).getNativeBuffer())))));
|
||||
return buffer;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static class DataBufferLeakInfo {
|
||||
|
||||
private final DataBuffer dataBuffer;
|
||||
|
||||
private final AssertionError error;
|
||||
|
||||
|
||||
DataBufferLeakInfo(DataBuffer dataBuffer, AssertionError error) {
|
||||
this.dataBuffer = dataBuffer;
|
||||
this.error = error;
|
||||
}
|
||||
|
||||
DataBuffer getDataBuffer() {
|
||||
return this.dataBuffer;
|
||||
}
|
||||
|
||||
AssertionError getError() {
|
||||
return this.error;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Store all intercepted incoming and outgoing payloads and then use
|
||||
* {@link #checkForLeaks()} at the end to check reference counts.
|
||||
*/
|
||||
private static class PayloadInterceptor extends AbstractRSocket implements RSocketInterceptor {
|
||||
|
||||
private final List<PayloadSavingDecorator> rsockets = new CopyOnWriteArrayList<>();
|
||||
|
||||
|
||||
void checkForLeaks() {
|
||||
this.rsockets.stream().map(PayloadSavingDecorator::getPayloads)
|
||||
.forEach(payloadInfoProcessor -> {
|
||||
payloadInfoProcessor.onComplete();
|
||||
payloadInfoProcessor
|
||||
.doOnNext(this::checkForLeak)
|
||||
.blockLast();
|
||||
});
|
||||
}
|
||||
|
||||
private void checkForLeak(PayloadLeakInfo info) {
|
||||
Instant start = Instant.now();
|
||||
while (true) {
|
||||
try {
|
||||
int count = info.getReferenceCount();
|
||||
assertTrue("Leaked payload (refCnt=" + count + "): " + info, count == 0);
|
||||
break;
|
||||
}
|
||||
catch (AssertionError ex) {
|
||||
if (Instant.now().isAfter(start.plus(Duration.ofSeconds(5)))) {
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
try {
|
||||
Thread.sleep(50);
|
||||
}
|
||||
catch (InterruptedException ex) {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void reset() {
|
||||
this.rsockets.forEach(PayloadSavingDecorator::reset);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public RSocket apply(RSocket rsocket) {
|
||||
PayloadSavingDecorator decorator = new PayloadSavingDecorator(rsocket);
|
||||
this.rsockets.add(decorator);
|
||||
return decorator;
|
||||
}
|
||||
|
||||
|
||||
private static class PayloadSavingDecorator extends AbstractRSocket {
|
||||
|
||||
private final RSocket delegate;
|
||||
|
||||
private ReplayProcessor<PayloadLeakInfo> payloads = ReplayProcessor.create();
|
||||
|
||||
|
||||
PayloadSavingDecorator(RSocket delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
|
||||
ReplayProcessor<PayloadLeakInfo> getPayloads() {
|
||||
return this.payloads;
|
||||
}
|
||||
|
||||
void reset() {
|
||||
this.payloads = ReplayProcessor.create();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> fireAndForget(io.rsocket.Payload payload) {
|
||||
return this.delegate.fireAndForget(addPayload(payload));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<io.rsocket.Payload> requestResponse(io.rsocket.Payload payload) {
|
||||
return this.delegate.requestResponse(addPayload(payload)).doOnSuccess(this::addPayload);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<io.rsocket.Payload> requestStream(io.rsocket.Payload payload) {
|
||||
return this.delegate.requestStream(addPayload(payload)).doOnNext(this::addPayload);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<io.rsocket.Payload> requestChannel(Publisher<io.rsocket.Payload> payloads) {
|
||||
return this.delegate
|
||||
.requestChannel(Flux.from(payloads).doOnNext(this::addPayload))
|
||||
.doOnNext(this::addPayload);
|
||||
}
|
||||
|
||||
private io.rsocket.Payload addPayload(io.rsocket.Payload payload) {
|
||||
this.payloads.onNext(new PayloadLeakInfo(payload));
|
||||
return payload;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> metadataPush(io.rsocket.Payload payload) {
|
||||
return this.delegate.metadataPush(addPayload(payload));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static class PayloadLeakInfo {
|
||||
|
||||
private final String description;
|
||||
|
||||
private final ReferenceCounted referenceCounted;
|
||||
|
||||
|
||||
PayloadLeakInfo(io.rsocket.Payload payload) {
|
||||
this.description = payload.toString();
|
||||
this.referenceCounted = payload;
|
||||
}
|
||||
|
||||
|
||||
int getReferenceCount() {
|
||||
return this.referenceCounted.refCnt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return this.description;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -17,6 +17,8 @@ package org.springframework.messaging.rsocket;
|
|||
|
||||
import java.time.Duration;
|
||||
|
||||
import io.netty.buffer.PooledByteBufAllocator;
|
||||
import io.rsocket.Frame;
|
||||
import io.rsocket.RSocket;
|
||||
import io.rsocket.RSocketFactory;
|
||||
import io.rsocket.transport.netty.client.TcpClientTransport;
|
||||
|
|
@ -35,6 +37,7 @@ import org.springframework.context.annotation.Bean;
|
|||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.core.codec.CharSequenceEncoder;
|
||||
import org.springframework.core.codec.StringDecoder;
|
||||
import org.springframework.core.io.buffer.NettyDataBufferFactory;
|
||||
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
|
||||
import org.springframework.messaging.handler.annotation.MessageMapping;
|
||||
import org.springframework.stereotype.Controller;
|
||||
|
|
@ -68,6 +71,7 @@ public class RSocketClientToServerIntegrationTests {
|
|||
|
||||
server = RSocketFactory.receive()
|
||||
.addServerPlugin(interceptor)
|
||||
.frameDecoder(Frame::retain) // as per https://github.com/rsocket/rsocket-java#zero-copy
|
||||
.acceptor(context.getBean(MessageHandlerAcceptor.class))
|
||||
.transport(TcpServerTransport.create("localhost", 7000))
|
||||
.start()
|
||||
|
|
@ -75,6 +79,7 @@ public class RSocketClientToServerIntegrationTests {
|
|||
|
||||
client = RSocketFactory.connect()
|
||||
.dataMimeType(MimeTypeUtils.TEXT_PLAIN_VALUE)
|
||||
.frameDecoder(Frame::retain) // as per https://github.com/rsocket/rsocket-java#zero-copy
|
||||
.transport(TcpClientTransport.create("localhost", 7000))
|
||||
.start()
|
||||
.block();
|
||||
|
|
@ -261,6 +266,7 @@ public class RSocketClientToServerIntegrationTests {
|
|||
return RSocketStrategies.builder()
|
||||
.decoder(StringDecoder.allMimeTypes())
|
||||
.encoder(CharSequenceEncoder.allMimeTypes())
|
||||
.dataBufferFactory(new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ package org.springframework.messaging.rsocket;
|
|||
import java.time.Duration;
|
||||
import java.util.Collections;
|
||||
|
||||
import io.netty.buffer.PooledByteBufAllocator;
|
||||
import io.rsocket.Closeable;
|
||||
import io.rsocket.Frame;
|
||||
import io.rsocket.RSocket;
|
||||
import io.rsocket.RSocketFactory;
|
||||
import io.rsocket.transport.netty.client.TcpClientTransport;
|
||||
|
|
@ -39,6 +41,7 @@ import org.springframework.context.annotation.Bean;
|
|||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.core.codec.CharSequenceEncoder;
|
||||
import org.springframework.core.codec.StringDecoder;
|
||||
import org.springframework.core.io.buffer.NettyDataBufferFactory;
|
||||
import org.springframework.messaging.handler.annotation.MessageMapping;
|
||||
import org.springframework.stereotype.Controller;
|
||||
|
||||
|
|
@ -61,6 +64,7 @@ public class RSocketServerToClientIntegrationTests {
|
|||
context = new AnnotationConfigApplicationContext(RSocketConfig.class);
|
||||
|
||||
server = RSocketFactory.receive()
|
||||
.frameDecoder(Frame::retain) // as per https://github.com/rsocket/rsocket-java#zero-copy
|
||||
.acceptor(context.getBean("serverAcceptor", MessageHandlerAcceptor.class))
|
||||
.transport(TcpServerTransport.create("localhost", 7000))
|
||||
.start()
|
||||
|
|
@ -104,6 +108,7 @@ public class RSocketServerToClientIntegrationTests {
|
|||
rsocket = RSocketFactory.connect()
|
||||
.setupPayload(DefaultPayload.create("", destination))
|
||||
.dataMimeType("text/plain")
|
||||
.frameDecoder(Frame::retain) // as per https://github.com/rsocket/rsocket-java#zero-copy
|
||||
.acceptor(context.getBean("clientAcceptor", MessageHandlerAcceptor.class))
|
||||
.transport(TcpClientTransport.create("localhost", 7000))
|
||||
.start()
|
||||
|
|
@ -272,6 +277,7 @@ public class RSocketServerToClientIntegrationTests {
|
|||
return RSocketStrategies.builder()
|
||||
.decoder(StringDecoder.allMimeTypes())
|
||||
.encoder(CharSequenceEncoder.allMimeTypes())
|
||||
.dataBufferFactory(new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue