Updates for buffer management in RSocket
- Integration tests run with zero copy configuration. - RSocketBufferLeakTests has been added. - Updates in MessagingRSocket to ensure proper release See gh-21987
This commit is contained in:
		
							parent
							
								
									23b39ad27b
								
							
						
					
					
						commit
						9e7f557b4a
					
				|  | @ -1,5 +1,5 @@ | |||
| /* | ||||
|  * Copyright 2002-2018 the original author or authors. | ||||
|  * Copyright 2002-2019 the original author or authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  | @ -105,13 +105,15 @@ public abstract class AbstractDataBufferAllocatingTestCase { | |||
| 	 */ | ||||
| 	protected void waitForDataBufferRelease(Duration duration) throws InterruptedException { | ||||
| 		Instant start = Instant.now(); | ||||
| 		while (Instant.now().isBefore(start.plus(duration))) { | ||||
| 		while (true) { | ||||
| 			try { | ||||
| 				verifyAllocations(); | ||||
| 				break; | ||||
| 			} | ||||
| 			catch (AssertionError ex) { | ||||
| 				// ignore; | ||||
| 				if (Instant.now().isAfter(start.plus(duration))) { | ||||
| 					throw ex; | ||||
| 				} | ||||
| 			} | ||||
| 			Thread.sleep(50); | ||||
| 		} | ||||
|  |  | |||
|  | @ -396,10 +396,10 @@ public abstract class AbstractMethodMessageHandler<T> | |||
| 		if (matches.size() > 1) { | ||||
| 			Match<T> secondBestMatch = matches.get(1); | ||||
| 			if (comparator.compare(bestMatch, secondBestMatch) == 0) { | ||||
| 				Method m1 = bestMatch.handlerMethod.getMethod(); | ||||
| 				Method m2 = secondBestMatch.handlerMethod.getMethod(); | ||||
| 				HandlerMethod m1 = bestMatch.handlerMethod; | ||||
| 				HandlerMethod m2 = secondBestMatch.handlerMethod; | ||||
| 				throw new IllegalStateException("Ambiguous handler methods mapped for destination '" + | ||||
| 						destination + "': {" + m1 + ", " + m2 + "}"); | ||||
| 						destination + "': {" + m1.getShortLogMessage() + ", " + m2.getShortLogMessage() + "}"); | ||||
| 			} | ||||
| 		} | ||||
| 		return bestMatch; | ||||
|  |  | |||
|  | @ -244,7 +244,7 @@ final class DefaultRSocketRequester implements RSocketRequester { | |||
| 
 | ||||
| 			Decoder<?> decoder = strategies.decoder(elementType, dataMimeType); | ||||
| 			return (Mono<T>) decoder.decodeToMono( | ||||
| 					payloadMono.map(this::wrapPayloadData), elementType, dataMimeType, EMPTY_HINTS); | ||||
| 					payloadMono.map(this::retainDataAndReleasePayload), elementType, dataMimeType, EMPTY_HINTS); | ||||
| 		} | ||||
| 
 | ||||
| 		@SuppressWarnings("unchecked") | ||||
|  | @ -260,12 +260,12 @@ final class DefaultRSocketRequester implements RSocketRequester { | |||
| 
 | ||||
| 			Decoder<?> decoder = strategies.decoder(elementType, dataMimeType); | ||||
| 
 | ||||
| 			return payloadFlux.map(this::wrapPayloadData).concatMap(dataBuffer -> | ||||
| 			return payloadFlux.map(this::retainDataAndReleasePayload).concatMap(dataBuffer -> | ||||
| 					(Mono<T>) decoder.decodeToMono(Mono.just(dataBuffer), elementType, dataMimeType, EMPTY_HINTS)); | ||||
| 		} | ||||
| 
 | ||||
| 		private DataBuffer wrapPayloadData(Payload payload) { | ||||
| 			return PayloadUtils.wrapPayloadData(payload, strategies.dataBufferFactory()); | ||||
| 		private DataBuffer retainDataAndReleasePayload(Payload payload) { | ||||
| 			return PayloadUtils.retainDataAndReleasePayload(payload, strategies.dataBufferFactory()); | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -21,14 +21,13 @@ import java.util.Collections; | |||
| import java.util.List; | ||||
| import java.util.function.Consumer; | ||||
| 
 | ||||
| import io.netty.buffer.PooledByteBufAllocator; | ||||
| 
 | ||||
| import org.springframework.core.ReactiveAdapterRegistry; | ||||
| import org.springframework.core.codec.Decoder; | ||||
| import org.springframework.core.codec.Encoder; | ||||
| import org.springframework.core.io.buffer.DataBufferFactory; | ||||
| import org.springframework.core.io.buffer.NettyDataBufferFactory; | ||||
| import org.springframework.core.io.buffer.DefaultDataBufferFactory; | ||||
| import org.springframework.lang.Nullable; | ||||
| import org.springframework.util.Assert; | ||||
| 
 | ||||
| /** | ||||
|  * Default, package-private {@link RSocketStrategies} implementation. | ||||
|  | @ -88,11 +87,10 @@ final class DefaultRSocketStrategies implements RSocketStrategies { | |||
| 
 | ||||
| 		private final List<Decoder<?>> decoders = new ArrayList<>(); | ||||
| 
 | ||||
| 		@Nullable | ||||
| 		private ReactiveAdapterRegistry adapterRegistry; | ||||
| 		private ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry.getSharedInstance(); | ||||
| 
 | ||||
| 		@Nullable | ||||
| 		private DataBufferFactory bufferFactory; | ||||
| 		private DataBufferFactory dataBufferFactory; | ||||
| 
 | ||||
| 
 | ||||
| 		@Override | ||||
|  | @ -121,23 +119,21 @@ final class DefaultRSocketStrategies implements RSocketStrategies { | |||
| 
 | ||||
| 		@Override | ||||
| 		public Builder reactiveAdapterStrategy(ReactiveAdapterRegistry registry) { | ||||
| 			Assert.notNull(registry, "ReactiveAdapterRegistry is required"); | ||||
| 			this.adapterRegistry = registry; | ||||
| 			return this; | ||||
| 		} | ||||
| 
 | ||||
| 		@Override | ||||
| 		public Builder dataBufferFactory(DataBufferFactory bufferFactory) { | ||||
| 			this.bufferFactory = bufferFactory; | ||||
| 			this.dataBufferFactory = bufferFactory; | ||||
| 			return this; | ||||
| 		} | ||||
| 
 | ||||
| 		@Override | ||||
| 		public RSocketStrategies build() { | ||||
| 			return new DefaultRSocketStrategies(this.encoders, this.decoders, | ||||
| 					this.adapterRegistry != null ? | ||||
| 							this.adapterRegistry : ReactiveAdapterRegistry.getSharedInstance(), | ||||
| 					this.bufferFactory != null ? this.bufferFactory : | ||||
| 							new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)); | ||||
| 			return new DefaultRSocketStrategies(this.encoders, this.decoders, this.adapterRegistry, | ||||
| 					this.dataBufferFactory != null ? this.dataBufferFactory : new DefaultDataBufferFactory()); | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -15,6 +15,7 @@ | |||
|  */ | ||||
| package org.springframework.messaging.rsocket; | ||||
| 
 | ||||
| import java.util.concurrent.atomic.AtomicBoolean; | ||||
| import java.util.function.Function; | ||||
| 
 | ||||
| import io.rsocket.AbstractRSocket; | ||||
|  | @ -29,7 +30,7 @@ import reactor.core.publisher.MonoProcessor; | |||
| import org.springframework.core.io.buffer.DataBuffer; | ||||
| import org.springframework.core.io.buffer.DataBufferFactory; | ||||
| import org.springframework.core.io.buffer.DataBufferUtils; | ||||
| import org.springframework.core.io.buffer.PooledDataBuffer; | ||||
| import org.springframework.core.io.buffer.NettyDataBuffer; | ||||
| import org.springframework.lang.Nullable; | ||||
| import org.springframework.messaging.Message; | ||||
| import org.springframework.messaging.MessageHeaders; | ||||
|  | @ -84,6 +85,9 @@ class MessagingRSocket extends AbstractRSocket { | |||
| 		if (StringUtils.hasText(payload.dataMimeType())) { | ||||
| 			this.dataMimeType = MimeTypeUtils.parseMimeType(payload.dataMimeType()); | ||||
| 		} | ||||
| 		// frameDecoder does not apply to connectionSetupPayload | ||||
| 		// so retain here since handle expects it.. | ||||
| 		payload.retain(); | ||||
| 		return handle(payload); | ||||
| 	} | ||||
| 
 | ||||
|  | @ -120,54 +124,72 @@ class MessagingRSocket extends AbstractRSocket { | |||
| 
 | ||||
| 
 | ||||
| 	private Mono<Void> handle(Payload payload) { | ||||
| 		Message<?> message = MessageBuilder.createMessage( | ||||
| 				Mono.fromCallable(() -> wrapPayloadData(payload)), createHeaders(payload, null)); | ||||
| 		String destination = getDestination(payload); | ||||
| 		MessageHeaders headers = createHeaders(destination, null); | ||||
| 		DataBuffer dataBuffer = retainDataAndReleasePayload(payload); | ||||
| 		int refCount = refCount(dataBuffer); | ||||
| 		Message<?> message = MessageBuilder.createMessage(dataBuffer, headers); | ||||
| 		return Mono.defer(() -> this.handler.apply(message)) | ||||
| 				.doFinally(s -> { | ||||
| 					if (refCount(dataBuffer) == refCount) { | ||||
| 						DataBufferUtils.release(dataBuffer); | ||||
| 					} | ||||
| 				}); | ||||
| 	} | ||||
| 
 | ||||
| 		return this.handler.apply(message); | ||||
| 	private int refCount(DataBuffer dataBuffer) { | ||||
| 		return dataBuffer instanceof NettyDataBuffer ? | ||||
| 				((NettyDataBuffer) dataBuffer).getNativeBuffer().refCnt() : 1; | ||||
| 	} | ||||
| 
 | ||||
| 	private Flux<Payload> handleAndReply(Payload firstPayload, Flux<Payload> payloads) { | ||||
| 		MonoProcessor<Flux<Payload>> replyMono = MonoProcessor.create(); | ||||
| 		Message<?> message = MessageBuilder.createMessage( | ||||
| 				payloads.map(this::wrapPayloadData).doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release), | ||||
| 				createHeaders(firstPayload, replyMono)); | ||||
| 		String destination = getDestination(firstPayload); | ||||
| 		MessageHeaders headers = createHeaders(destination, replyMono); | ||||
| 
 | ||||
| 		return this.handler.apply(message) | ||||
| 		AtomicBoolean read = new AtomicBoolean(); | ||||
| 		Flux<DataBuffer> buffers = payloads.map(this::retainDataAndReleasePayload).doOnSubscribe(s -> read.set(true)); | ||||
| 		Message<Flux<DataBuffer>> message = MessageBuilder.createMessage(buffers, headers); | ||||
| 
 | ||||
| 		return Mono.defer(() -> this.handler.apply(message)) | ||||
| 				.doFinally(s -> { | ||||
| 					// Subscription should have happened by now due to ChannelSendOperator | ||||
| 					if (!read.get()) { | ||||
| 						buffers.subscribe(DataBufferUtils::release); | ||||
| 					} | ||||
| 				}) | ||||
| 				.thenMany(Flux.defer(() -> replyMono.isTerminated() ? | ||||
| 						replyMono.flatMapMany(Function.identity()) : | ||||
| 						Mono.error(new IllegalStateException("Something went wrong: reply Mono not set")))); | ||||
| 	} | ||||
| 
 | ||||
| 	private MessageHeaders createHeaders(Payload payload, @Nullable MonoProcessor<?> replyMono) { | ||||
| 	private String getDestination(Payload payload) { | ||||
| 
 | ||||
| 		// TODO: | ||||
| 		// For now treat the metadata as a simple string with routing information. | ||||
| 		// We'll have to get more sophisticated once the routing extension is completed. | ||||
| 		// https://github.com/rsocket/rsocket-java/issues/568 | ||||
| 
 | ||||
| 		return payload.getMetadataUtf8(); | ||||
| 	} | ||||
| 
 | ||||
| 	private DataBuffer retainDataAndReleasePayload(Payload payload) { | ||||
| 		return PayloadUtils.retainDataAndReleasePayload(payload, this.strategies.dataBufferFactory()); | ||||
| 	} | ||||
| 
 | ||||
| 	private MessageHeaders createHeaders(String destination, @Nullable MonoProcessor<?> replyMono) { | ||||
| 		MessageHeaderAccessor headers = new MessageHeaderAccessor(); | ||||
| 
 | ||||
| 		String destination = payload.getMetadataUtf8(); | ||||
| 		headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination); | ||||
| 
 | ||||
| 		if (this.dataMimeType != null) { | ||||
| 			headers.setContentType(this.dataMimeType); | ||||
| 		} | ||||
| 
 | ||||
| 		headers.setHeader(RSocketRequesterMethodArgumentResolver.RSOCKET_REQUESTER_HEADER, this.requester); | ||||
| 
 | ||||
| 		if (replyMono != null) { | ||||
| 			headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono); | ||||
| 		} | ||||
| 
 | ||||
| 		DataBufferFactory bufferFactory = this.strategies.dataBufferFactory(); | ||||
| 		headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, bufferFactory); | ||||
| 
 | ||||
| 		return headers.getMessageHeaders(); | ||||
| 	} | ||||
| 
 | ||||
| 	private DataBuffer wrapPayloadData(Payload payload) { | ||||
| 		return PayloadUtils.wrapPayloadData(payload, this.strategies.dataBufferFactory()); | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
|  |  | |||
|  | @ -15,6 +15,8 @@ | |||
|  */ | ||||
| package org.springframework.messaging.rsocket; | ||||
| 
 | ||||
| import io.netty.buffer.ByteBuf; | ||||
| import io.rsocket.Frame; | ||||
| import io.rsocket.Payload; | ||||
| import io.rsocket.util.ByteBufPayload; | ||||
| import io.rsocket.util.DefaultPayload; | ||||
|  | @ -24,6 +26,7 @@ import org.springframework.core.io.buffer.DataBufferFactory; | |||
| import org.springframework.core.io.buffer.DefaultDataBuffer; | ||||
| import org.springframework.core.io.buffer.NettyDataBuffer; | ||||
| import org.springframework.core.io.buffer.NettyDataBufferFactory; | ||||
| import org.springframework.util.Assert; | ||||
| 
 | ||||
| /** | ||||
|  * Static utility methods to create {@link Payload} from {@link DataBuffer}s | ||||
|  | @ -35,19 +38,31 @@ import org.springframework.core.io.buffer.NettyDataBufferFactory; | |||
| abstract class PayloadUtils { | ||||
| 
 | ||||
| 	/** | ||||
| 	 * Return the Payload data wrapped as DataBuffer. If the bufferFactory is | ||||
| 	 * {@link NettyDataBufferFactory} the payload retained and sliced. | ||||
| 	 * @param payload the input payload | ||||
| 	 * @param bufferFactory the BufferFactory to use to wrap | ||||
| 	 * @return the DataBuffer wrapper | ||||
| 	 * Use this method to slice, retain and wrap the data portion of the | ||||
| 	 * {@code Payload}, and also to release the {@code Payload}. This assumes | ||||
| 	 * the Payload metadata has been read by now and ensures downstream code | ||||
| 	 * need only be aware of {@code DataBuffer}s. | ||||
| 	 * @param payload the payload to process | ||||
| 	 * @param bufferFactory the DataBufferFactory to wrap with | ||||
| 	 * @return the created {@code DataBuffer} instance | ||||
| 	 */ | ||||
| 	public static DataBuffer wrapPayloadData(Payload payload, DataBufferFactory bufferFactory) { | ||||
| 		if (bufferFactory instanceof NettyDataBufferFactory) { | ||||
| 			return ((NettyDataBufferFactory) bufferFactory).wrap(payload.retain().sliceData()); | ||||
| 		} | ||||
| 		else { | ||||
| 	public static DataBuffer retainDataAndReleasePayload(Payload payload, DataBufferFactory bufferFactory) { | ||||
| 		try { | ||||
| 			if (bufferFactory instanceof NettyDataBufferFactory) { | ||||
| 				ByteBuf byteBuf = payload.sliceData().retain(); | ||||
| 				return ((NettyDataBufferFactory) bufferFactory).wrap(byteBuf); | ||||
| 			} | ||||
| 
 | ||||
| 			Assert.isTrue(!(payload instanceof ByteBufPayload) && !(payload instanceof Frame), | ||||
| 					"NettyDataBufferFactory expected, actual: " + bufferFactory.getClass().getSimpleName()); | ||||
| 
 | ||||
| 			return bufferFactory.wrap(payload.getData()); | ||||
| 		} | ||||
| 		finally { | ||||
| 			if (payload.refCnt() > 0) { | ||||
| 				payload.release(); | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	/** | ||||
|  |  | |||
|  | @ -142,12 +142,23 @@ public interface RSocketStrategies { | |||
| 		Builder reactiveAdapterStrategy(ReactiveAdapterRegistry registry); | ||||
| 
 | ||||
| 		/** | ||||
| 		 * Configure the DataBufferFactory to use for the allocation of buffers | ||||
| 		 * when creating or responding requests. | ||||
| 		 * <p>By default this is an instance of | ||||
| 		 * Configure the DataBufferFactory to use for allocating buffers, for | ||||
| 		 * example when preparing requests or when responding. The choice here | ||||
| 		 * must be aligned with the frame decoder configured in | ||||
| 		 * {@link io.rsocket.RSocketFactory}. | ||||
| 		 * <p>By default this property is an instance of | ||||
| 		 * {@link org.springframework.core.io.buffer.DefaultDataBufferFactory | ||||
| 		 * DefaultDataBufferFactory} matching to the default frame decoder in | ||||
| 		 * {@link io.rsocket.RSocketFactory} which copies the payload. This | ||||
| 		 * comes at cost to performance but does not require reference counting | ||||
| 		 * and eliminates possibility for memory leaks. | ||||
| 		 * <p>To switch to a zero-copy strategy, | ||||
| 		 * <a href="https://github.com/rsocket/rsocket-java#zero-copy">configure RSocket</a> | ||||
| 		 * accordingly, and then configure this property with an instance of | ||||
| 		 * {@link org.springframework.core.io.buffer.NettyDataBufferFactory | ||||
| 		 * NettyDataBufferFactory} with {@link PooledByteBufAllocator#DEFAULT}. | ||||
| 		 * @param bufferFactory the buffer factory to use | ||||
| 		 * NettyDataBufferFactory} with a pooled allocator such as | ||||
| 		 * {@link PooledByteBufAllocator#DEFAULT}. | ||||
| 		 * @param bufferFactory the DataBufferFactory to use | ||||
| 		 */ | ||||
| 		Builder dataBufferFactory(DataBufferFactory bufferFactory); | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,466 @@ | |||
| /* | ||||
|  * Copyright 2002-2019 the original author or authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *      http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| package org.springframework.messaging.rsocket; | ||||
| 
 | ||||
| import java.time.Duration; | ||||
| import java.time.Instant; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| import java.util.concurrent.CopyOnWriteArrayList; | ||||
| 
 | ||||
| import io.netty.buffer.ByteBuf; | ||||
| import io.netty.buffer.ByteBufAllocator; | ||||
| import io.netty.buffer.PooledByteBufAllocator; | ||||
| import io.netty.buffer.Unpooled; | ||||
| import io.netty.util.ReferenceCounted; | ||||
| import io.rsocket.AbstractRSocket; | ||||
| import io.rsocket.Frame; | ||||
| import io.rsocket.RSocket; | ||||
| import io.rsocket.RSocketFactory; | ||||
| import io.rsocket.plugins.RSocketInterceptor; | ||||
| import io.rsocket.transport.netty.client.TcpClientTransport; | ||||
| import io.rsocket.transport.netty.server.CloseableChannel; | ||||
| import io.rsocket.transport.netty.server.TcpServerTransport; | ||||
| import org.junit.After; | ||||
| import org.junit.AfterClass; | ||||
| import org.junit.Before; | ||||
| import org.junit.BeforeClass; | ||||
| import org.junit.Test; | ||||
| import org.reactivestreams.Publisher; | ||||
| import reactor.core.publisher.Flux; | ||||
| import reactor.core.publisher.Mono; | ||||
| import reactor.core.publisher.ReplayProcessor; | ||||
| import reactor.test.StepVerifier; | ||||
| 
 | ||||
| import org.springframework.context.annotation.AnnotationConfigApplicationContext; | ||||
| import org.springframework.context.annotation.Bean; | ||||
| import org.springframework.context.annotation.Configuration; | ||||
| import org.springframework.core.codec.CharSequenceEncoder; | ||||
| import org.springframework.core.codec.StringDecoder; | ||||
| import org.springframework.core.io.Resource; | ||||
| import org.springframework.core.io.buffer.DataBuffer; | ||||
| import org.springframework.core.io.buffer.NettyDataBuffer; | ||||
| import org.springframework.core.io.buffer.NettyDataBufferFactory; | ||||
| import org.springframework.core.io.buffer.PooledDataBuffer; | ||||
| import org.springframework.messaging.handler.annotation.MessageExceptionHandler; | ||||
| import org.springframework.messaging.handler.annotation.MessageMapping; | ||||
| import org.springframework.messaging.handler.annotation.Payload; | ||||
| import org.springframework.stereotype.Controller; | ||||
| import org.springframework.util.MimeTypeUtils; | ||||
| import org.springframework.util.ObjectUtils; | ||||
| 
 | ||||
| import static org.junit.Assert.*; | ||||
| 
 | ||||
| /** | ||||
|  * Tests for scenarios that could lead to Payload and/or DataBuffer leaks. | ||||
|  * | ||||
|  * @author Rossen Stoyanchev | ||||
|  */ | ||||
| public class RSocketBufferLeakTests { | ||||
| 
 | ||||
| 	private static AnnotationConfigApplicationContext context; | ||||
| 
 | ||||
| 	private static final PayloadInterceptor payloadInterceptor = new PayloadInterceptor(); | ||||
| 
 | ||||
| 	private static CloseableChannel server; | ||||
| 
 | ||||
| 	private static RSocket client; | ||||
| 
 | ||||
| 	private static RSocketRequester requester; | ||||
| 
 | ||||
| 
 | ||||
| 	@BeforeClass | ||||
| 	@SuppressWarnings("ConstantConditions") | ||||
| 	public static void setupOnce() { | ||||
| 
 | ||||
| 		context = new AnnotationConfigApplicationContext(ServerConfig.class); | ||||
| 
 | ||||
| 		server = RSocketFactory.receive() | ||||
| 				.frameDecoder(Frame::retain) // zero copy | ||||
| 				.addServerPlugin(payloadInterceptor) // intercept responding | ||||
| 				.acceptor(context.getBean(MessageHandlerAcceptor.class)) | ||||
| 				.transport(TcpServerTransport.create("localhost", 7000)) | ||||
| 				.start() | ||||
| 				.block(); | ||||
| 
 | ||||
| 		client = RSocketFactory.connect() | ||||
| 				.frameDecoder(Frame::retain) // zero copy | ||||
| 				.addClientPlugin(payloadInterceptor) // intercept outgoing requests | ||||
| 				.dataMimeType(MimeTypeUtils.TEXT_PLAIN_VALUE) | ||||
| 				.transport(TcpClientTransport.create("localhost", 7000)) | ||||
| 				.start() | ||||
| 				.block(); | ||||
| 
 | ||||
| 		requester = RSocketRequester.create( | ||||
| 				client, MimeTypeUtils.TEXT_PLAIN, context.getBean(RSocketStrategies.class)); | ||||
| 	} | ||||
| 
 | ||||
| 	@AfterClass | ||||
| 	public static void tearDownOnce() { | ||||
| 		client.dispose(); | ||||
| 		server.dispose(); | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	@Before | ||||
| 	public void setUp() { | ||||
| 		getLeakAwareNettyDataBufferFactory().reset(); | ||||
| 		payloadInterceptor.reset(); | ||||
| 	} | ||||
| 
 | ||||
| 	@After | ||||
| 	public void tearDown() throws InterruptedException { | ||||
| 		getLeakAwareNettyDataBufferFactory().checkForLeaks(Duration.ofSeconds(5)); | ||||
| 		payloadInterceptor.checkForLeaks(); | ||||
| 	} | ||||
| 
 | ||||
| 	private LeakAwareNettyDataBufferFactory getLeakAwareNettyDataBufferFactory() { | ||||
| 		return (LeakAwareNettyDataBufferFactory) context.getBean(RSocketStrategies.class).dataBufferFactory(); | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void assemblyTimeErrorForHandleAndReply() { | ||||
| 		Mono<String> result = requester.route("A.B").data("foo").retrieveMono(String.class); | ||||
| 		StepVerifier.create(result).expectErrorMatches(ex -> { | ||||
| 			String prefix = "Ambiguous handler methods mapped for destination 'A.B':"; | ||||
| 			return ex.getMessage().startsWith(prefix); | ||||
| 		}).verify(); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void subscriptionTimeErrorForHandleAndReply() { | ||||
| 		Mono<String> result = requester.route("not-decodable").data("foo").retrieveMono(String.class); | ||||
| 		StepVerifier.create(result).expectErrorMatches(ex -> { | ||||
| 			String prefix = "Cannot decode to [org.springframework.core.io.Resource]"; | ||||
| 			return ex.getMessage().contains(prefix); | ||||
| 		}).verify(); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void errorSignalWithExceptionHandler() { | ||||
| 		Mono<String> result = requester.route("error-signal").data("foo").retrieveMono(String.class); | ||||
| 		StepVerifier.create(result).expectNext("Handled 'bad input'").verifyComplete(); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void ignoreInput() { | ||||
| 		Flux<String> result = requester.route("ignore-input").data("a").retrieveFlux(String.class); | ||||
| 		StepVerifier.create(result).expectNext("bar").verifyComplete(); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void retrieveMonoFromFluxResponderMethod() { | ||||
| 		Mono<String> result = requester.route("request-stream").data("foo").retrieveMono(String.class); | ||||
| 		StepVerifier.create(result).expectNext("foo-1").verifyComplete(); | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	@Controller | ||||
| 	static class ServerController { | ||||
| 
 | ||||
| 		@MessageMapping("A.*") | ||||
| 		void ambiguousMatchA(String payload) { | ||||
| 			throw new IllegalStateException("Unexpected call"); | ||||
| 		} | ||||
| 
 | ||||
| 		@MessageMapping("*.B") | ||||
| 		void ambiguousMatchB(String payload) { | ||||
| 			throw new IllegalStateException("Unexpected call"); | ||||
| 		} | ||||
| 
 | ||||
| 		@MessageMapping("not-decodable") | ||||
| 		void notDecodable(@Payload Resource resource) { | ||||
| 			throw new IllegalStateException("Unexpected call"); | ||||
| 		} | ||||
| 
 | ||||
| 		@MessageMapping("error-signal") | ||||
| 		public Flux<String> errorSignal(String payload) { | ||||
| 			return Flux.error(new IllegalArgumentException("bad input")) | ||||
| 					.delayElements(Duration.ofMillis(10)) | ||||
| 					.cast(String.class); | ||||
| 		} | ||||
| 
 | ||||
| 		@MessageExceptionHandler | ||||
| 		public String handleIllegalArgument(IllegalArgumentException ex) { | ||||
| 			return "Handled '" + ex.getMessage() + "'"; | ||||
| 		} | ||||
| 
 | ||||
| 		@MessageMapping("ignore-input") | ||||
| 		Mono<String> ignoreInput() { | ||||
| 			return Mono.delay(Duration.ofMillis(10)).map(l -> "bar"); | ||||
| 		} | ||||
| 
 | ||||
| 		@MessageMapping("request-stream") | ||||
| 		Flux<String> stream(String payload) { | ||||
| 			return Flux.range(1,100).delayElements(Duration.ofMillis(10)).map(idx -> payload + "-" + idx); | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	@Configuration | ||||
| 	static class ServerConfig { | ||||
| 
 | ||||
| 		@Bean | ||||
| 		public ServerController controller() { | ||||
| 			return new ServerController(); | ||||
| 		} | ||||
| 
 | ||||
| 		@Bean | ||||
| 		public MessageHandlerAcceptor messageHandlerAcceptor() { | ||||
| 			MessageHandlerAcceptor acceptor = new MessageHandlerAcceptor(); | ||||
| 			acceptor.setRSocketStrategies(rsocketStrategies()); | ||||
| 			return acceptor; | ||||
| 		} | ||||
| 
 | ||||
| 		@Bean | ||||
| 		public RSocketStrategies rsocketStrategies() { | ||||
| 			return RSocketStrategies.builder() | ||||
| 					.decoder(StringDecoder.allMimeTypes()) | ||||
| 					.encoder(CharSequenceEncoder.allMimeTypes()) | ||||
| 					.dataBufferFactory(new LeakAwareNettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)) | ||||
| 					.build(); | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	/** | ||||
| 	 * Similar {@link org.springframework.core.io.buffer.LeakAwareDataBufferFactory} | ||||
| 	 * but extends {@link NettyDataBufferFactory} rather than rely on | ||||
| 	 * decoration, since {@link PayloadUtils} does instanceof checks. | ||||
| 	 */ | ||||
| 	private static class LeakAwareNettyDataBufferFactory extends NettyDataBufferFactory { | ||||
| 
 | ||||
| 		private final List<DataBufferLeakInfo> created = new ArrayList<>(); | ||||
| 
 | ||||
| 
 | ||||
| 		LeakAwareNettyDataBufferFactory(ByteBufAllocator byteBufAllocator) { | ||||
| 			super(byteBufAllocator); | ||||
| 		} | ||||
| 
 | ||||
| 
 | ||||
| 		void checkForLeaks(Duration duration) throws InterruptedException { | ||||
| 			Instant start = Instant.now(); | ||||
| 			while (true) { | ||||
| 				try { | ||||
| 					this.created.forEach(info -> { | ||||
| 						if (((PooledDataBuffer) info.getDataBuffer()).isAllocated()) { | ||||
| 							throw info.getError(); | ||||
| 						} | ||||
| 					}); | ||||
| 					break; | ||||
| 				} | ||||
| 				catch (AssertionError ex) { | ||||
| 					if (Instant.now().isAfter(start.plus(duration))) { | ||||
| 						throw ex; | ||||
| 					} | ||||
| 				} | ||||
| 				Thread.sleep(50); | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		void reset() { | ||||
| 			this.created.clear(); | ||||
| 		} | ||||
| 
 | ||||
| 
 | ||||
| 		@Override | ||||
| 		public NettyDataBuffer allocateBuffer() { | ||||
| 			return (NettyDataBuffer) record(super.allocateBuffer()); | ||||
| 		} | ||||
| 
 | ||||
| 		@Override | ||||
| 		public NettyDataBuffer allocateBuffer(int initialCapacity) { | ||||
| 			return (NettyDataBuffer) record(super.allocateBuffer(initialCapacity)); | ||||
| 		} | ||||
| 
 | ||||
| 		@Override | ||||
| 		public NettyDataBuffer wrap(ByteBuf byteBuf) { | ||||
| 			NettyDataBuffer dataBuffer = super.wrap(byteBuf); | ||||
| 			if (byteBuf != Unpooled.EMPTY_BUFFER) { | ||||
| 				record(dataBuffer); | ||||
| 			} | ||||
| 			return dataBuffer; | ||||
| 		} | ||||
| 
 | ||||
| 		@Override | ||||
| 		public DataBuffer join(List<? extends DataBuffer> dataBuffers) { | ||||
| 			return record(super.join(dataBuffers)); | ||||
| 		} | ||||
| 
 | ||||
| 		private DataBuffer record(DataBuffer buffer) { | ||||
| 			this.created.add(new DataBufferLeakInfo(buffer, new AssertionError(String.format( | ||||
| 					"DataBuffer leak: {%s} {%s} not released.%nStacktrace at buffer creation: ", buffer, | ||||
| 					ObjectUtils.getIdentityHexString(((NettyDataBuffer) buffer).getNativeBuffer()))))); | ||||
| 			return buffer; | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	private static class DataBufferLeakInfo { | ||||
| 
 | ||||
| 		private final DataBuffer dataBuffer; | ||||
| 
 | ||||
| 		private final AssertionError error; | ||||
| 
 | ||||
| 
 | ||||
| 		DataBufferLeakInfo(DataBuffer dataBuffer, AssertionError error) { | ||||
| 			this.dataBuffer = dataBuffer; | ||||
| 			this.error = error; | ||||
| 		} | ||||
| 
 | ||||
| 		DataBuffer getDataBuffer() { | ||||
| 			return this.dataBuffer; | ||||
| 		} | ||||
| 
 | ||||
| 		AssertionError getError() { | ||||
| 			return this.error; | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	/** | ||||
| 	 * Store all intercepted incoming and outgoing payloads and then use | ||||
| 	 * {@link #checkForLeaks()} at the end to check reference counts. | ||||
| 	 */ | ||||
| 	private static class PayloadInterceptor extends AbstractRSocket implements RSocketInterceptor { | ||||
| 
 | ||||
| 		private final List<PayloadSavingDecorator> rsockets = new CopyOnWriteArrayList<>(); | ||||
| 
 | ||||
| 
 | ||||
| 		void checkForLeaks() { | ||||
| 			this.rsockets.stream().map(PayloadSavingDecorator::getPayloads) | ||||
| 					.forEach(payloadInfoProcessor -> { | ||||
| 						payloadInfoProcessor.onComplete(); | ||||
| 						payloadInfoProcessor | ||||
| 								.doOnNext(this::checkForLeak) | ||||
| 								.blockLast(); | ||||
| 					}); | ||||
| 		} | ||||
| 
 | ||||
| 		private void checkForLeak(PayloadLeakInfo info) { | ||||
| 			Instant start = Instant.now(); | ||||
| 			while (true) { | ||||
| 				try { | ||||
| 					int count = info.getReferenceCount(); | ||||
| 					assertTrue("Leaked payload (refCnt=" + count + "): " + info, count == 0); | ||||
| 					break; | ||||
| 				} | ||||
| 				catch (AssertionError ex) { | ||||
| 					if (Instant.now().isAfter(start.plus(Duration.ofSeconds(5)))) { | ||||
| 						throw ex; | ||||
| 					} | ||||
| 				} | ||||
| 				try { | ||||
| 					Thread.sleep(50); | ||||
| 				} | ||||
| 				catch (InterruptedException ex) { | ||||
| 					// ignore | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		public void reset() { | ||||
| 			this.rsockets.forEach(PayloadSavingDecorator::reset); | ||||
| 		} | ||||
| 
 | ||||
| 
 | ||||
| 		@Override | ||||
| 		public RSocket apply(RSocket rsocket) { | ||||
| 			PayloadSavingDecorator decorator = new PayloadSavingDecorator(rsocket); | ||||
| 			this.rsockets.add(decorator); | ||||
| 			return decorator; | ||||
| 		} | ||||
| 
 | ||||
| 
 | ||||
| 		private static class PayloadSavingDecorator extends AbstractRSocket { | ||||
| 
 | ||||
| 			private final RSocket delegate; | ||||
| 
 | ||||
| 			private ReplayProcessor<PayloadLeakInfo> payloads = ReplayProcessor.create(); | ||||
| 
 | ||||
| 
 | ||||
| 			PayloadSavingDecorator(RSocket delegate) { | ||||
| 				this.delegate = delegate; | ||||
| 			} | ||||
| 
 | ||||
| 
 | ||||
| 			ReplayProcessor<PayloadLeakInfo> getPayloads() { | ||||
| 				return this.payloads; | ||||
| 			} | ||||
| 
 | ||||
| 			void reset() { | ||||
| 				this.payloads = ReplayProcessor.create(); | ||||
| 			} | ||||
| 
 | ||||
| 			@Override | ||||
| 			public Mono<Void> fireAndForget(io.rsocket.Payload payload) { | ||||
| 				return this.delegate.fireAndForget(addPayload(payload)); | ||||
| 			} | ||||
| 
 | ||||
| 			@Override | ||||
| 			public Mono<io.rsocket.Payload> requestResponse(io.rsocket.Payload payload) { | ||||
| 				return this.delegate.requestResponse(addPayload(payload)).doOnSuccess(this::addPayload); | ||||
| 			} | ||||
| 
 | ||||
| 			@Override | ||||
| 			public Flux<io.rsocket.Payload> requestStream(io.rsocket.Payload payload) { | ||||
| 				return this.delegate.requestStream(addPayload(payload)).doOnNext(this::addPayload); | ||||
| 			} | ||||
| 
 | ||||
| 			@Override | ||||
| 			public Flux<io.rsocket.Payload> requestChannel(Publisher<io.rsocket.Payload> payloads) { | ||||
| 				return this.delegate | ||||
| 						.requestChannel(Flux.from(payloads).doOnNext(this::addPayload)) | ||||
| 						.doOnNext(this::addPayload); | ||||
| 			} | ||||
| 
 | ||||
| 			private io.rsocket.Payload addPayload(io.rsocket.Payload payload) { | ||||
| 				this.payloads.onNext(new PayloadLeakInfo(payload)); | ||||
| 				return payload; | ||||
| 			} | ||||
| 
 | ||||
| 			@Override | ||||
| 			public Mono<Void> metadataPush(io.rsocket.Payload payload) { | ||||
| 				return this.delegate.metadataPush(addPayload(payload)); | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	private static class PayloadLeakInfo { | ||||
| 
 | ||||
| 		private final String description; | ||||
| 
 | ||||
| 		private final ReferenceCounted referenceCounted; | ||||
| 
 | ||||
| 
 | ||||
| 		PayloadLeakInfo(io.rsocket.Payload payload) { | ||||
| 			this.description = payload.toString(); | ||||
| 			this.referenceCounted = payload; | ||||
| 		} | ||||
| 
 | ||||
| 
 | ||||
| 		int getReferenceCount() { | ||||
| 			return this.referenceCounted.refCnt(); | ||||
| 		} | ||||
| 
 | ||||
| 		@Override | ||||
| 		public String toString() { | ||||
| 			return this.description; | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | @ -17,6 +17,8 @@ package org.springframework.messaging.rsocket; | |||
| 
 | ||||
| import java.time.Duration; | ||||
| 
 | ||||
| import io.netty.buffer.PooledByteBufAllocator; | ||||
| import io.rsocket.Frame; | ||||
| import io.rsocket.RSocket; | ||||
| import io.rsocket.RSocketFactory; | ||||
| import io.rsocket.transport.netty.client.TcpClientTransport; | ||||
|  | @ -35,6 +37,7 @@ import org.springframework.context.annotation.Bean; | |||
| import org.springframework.context.annotation.Configuration; | ||||
| import org.springframework.core.codec.CharSequenceEncoder; | ||||
| import org.springframework.core.codec.StringDecoder; | ||||
| import org.springframework.core.io.buffer.NettyDataBufferFactory; | ||||
| import org.springframework.messaging.handler.annotation.MessageExceptionHandler; | ||||
| import org.springframework.messaging.handler.annotation.MessageMapping; | ||||
| import org.springframework.stereotype.Controller; | ||||
|  | @ -68,6 +71,7 @@ public class RSocketClientToServerIntegrationTests { | |||
| 
 | ||||
| 		server = RSocketFactory.receive() | ||||
| 				.addServerPlugin(interceptor) | ||||
| 				.frameDecoder(Frame::retain)  // as per https://github.com/rsocket/rsocket-java#zero-copy | ||||
| 				.acceptor(context.getBean(MessageHandlerAcceptor.class)) | ||||
| 				.transport(TcpServerTransport.create("localhost", 7000)) | ||||
| 				.start() | ||||
|  | @ -75,6 +79,7 @@ public class RSocketClientToServerIntegrationTests { | |||
| 
 | ||||
| 		client = RSocketFactory.connect() | ||||
| 				.dataMimeType(MimeTypeUtils.TEXT_PLAIN_VALUE) | ||||
| 				.frameDecoder(Frame::retain)  // as per https://github.com/rsocket/rsocket-java#zero-copy | ||||
| 				.transport(TcpClientTransport.create("localhost", 7000)) | ||||
| 				.start() | ||||
| 				.block(); | ||||
|  | @ -261,6 +266,7 @@ public class RSocketClientToServerIntegrationTests { | |||
| 			return RSocketStrategies.builder() | ||||
| 					.decoder(StringDecoder.allMimeTypes()) | ||||
| 					.encoder(CharSequenceEncoder.allMimeTypes()) | ||||
| 					.dataBufferFactory(new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)) | ||||
| 					.build(); | ||||
| 		} | ||||
| 	} | ||||
|  |  | |||
|  | @ -18,7 +18,9 @@ package org.springframework.messaging.rsocket; | |||
| import java.time.Duration; | ||||
| import java.util.Collections; | ||||
| 
 | ||||
| import io.netty.buffer.PooledByteBufAllocator; | ||||
| import io.rsocket.Closeable; | ||||
| import io.rsocket.Frame; | ||||
| import io.rsocket.RSocket; | ||||
| import io.rsocket.RSocketFactory; | ||||
| import io.rsocket.transport.netty.client.TcpClientTransport; | ||||
|  | @ -39,6 +41,7 @@ import org.springframework.context.annotation.Bean; | |||
| import org.springframework.context.annotation.Configuration; | ||||
| import org.springframework.core.codec.CharSequenceEncoder; | ||||
| import org.springframework.core.codec.StringDecoder; | ||||
| import org.springframework.core.io.buffer.NettyDataBufferFactory; | ||||
| import org.springframework.messaging.handler.annotation.MessageMapping; | ||||
| import org.springframework.stereotype.Controller; | ||||
| 
 | ||||
|  | @ -61,6 +64,7 @@ public class RSocketServerToClientIntegrationTests { | |||
| 		context = new AnnotationConfigApplicationContext(RSocketConfig.class); | ||||
| 
 | ||||
| 		server = RSocketFactory.receive() | ||||
| 				.frameDecoder(Frame::retain)  // as per https://github.com/rsocket/rsocket-java#zero-copy | ||||
| 				.acceptor(context.getBean("serverAcceptor", MessageHandlerAcceptor.class)) | ||||
| 				.transport(TcpServerTransport.create("localhost", 7000)) | ||||
| 				.start() | ||||
|  | @ -104,6 +108,7 @@ public class RSocketServerToClientIntegrationTests { | |||
| 			rsocket = RSocketFactory.connect() | ||||
| 					.setupPayload(DefaultPayload.create("", destination)) | ||||
| 					.dataMimeType("text/plain") | ||||
| 					.frameDecoder(Frame::retain)  // as per https://github.com/rsocket/rsocket-java#zero-copy | ||||
| 					.acceptor(context.getBean("clientAcceptor", MessageHandlerAcceptor.class)) | ||||
| 					.transport(TcpClientTransport.create("localhost", 7000)) | ||||
| 					.start() | ||||
|  | @ -272,6 +277,7 @@ public class RSocketServerToClientIntegrationTests { | |||
| 			return RSocketStrategies.builder() | ||||
| 					.decoder(StringDecoder.allMimeTypes()) | ||||
| 					.encoder(CharSequenceEncoder.allMimeTypes()) | ||||
| 					.dataBufferFactory(new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)) | ||||
| 					.build(); | ||||
| 		} | ||||
| 	} | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue