Buffer leak fixes

Address issues where buffers are allocated (and cached somehow) at or
before subscription, and before explicit demand.

The commit adds tests proving the leaks and fixes. The common thread
for all tests is a "zero demand" subscriber that subscribes  but does
not request, and then cancels without consuming anything.

Closes gh-22107
This commit is contained in:
Rossen Stoyanchev 2019-03-26 21:11:19 -04:00
parent 65b46079a2
commit c54355784e
16 changed files with 504 additions and 211 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,6 +24,7 @@ import reactor.core.publisher.Flux;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.PooledDataBuffer;
import org.springframework.lang.Nullable;
import org.springframework.util.MimeType;
@ -47,9 +48,10 @@ public abstract class AbstractSingleValueEncoder<T> extends AbstractEncoder<T> {
public final Flux<DataBuffer> encode(Publisher<? extends T> inputStream, DataBufferFactory bufferFactory,
ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
return Flux.from(inputStream).
take(1).
concatMap(t -> encode(t, bufferFactory, elementType, mimeType, hints));
return Flux.from(inputStream)
.take(1)
.concatMap(value -> encode(value, bufferFactory, elementType, mimeType, hints))
.doOnDiscard(PooledDataBuffer.class, PooledDataBuffer::release);
}
/**

View File

@ -17,7 +17,6 @@
package org.springframework.core.codec;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.OptionalLong;
@ -89,24 +88,22 @@ public class ResourceRegionEncoder extends AbstractEncoder<ResourceRegion> {
return Mono.from(inputStream)
.flatMapMany(region -> {
if (!region.getResource().isReadable()) {
return Flux.error(new EncodingException("Resource " +
region.getResource() + " is not readable"));
return Flux.error(new EncodingException(
"Resource " + region.getResource() + " is not readable"));
}
return writeResourceRegion(region, bufferFactory, hints);
});
}
else {
final String boundaryString = Hints.getRequiredHint(hints, BOUNDARY_STRING_HINT);
byte[] startBoundary = getAsciiBytes("\r\n--" + boundaryString + "\r\n");
byte[] contentType =
(mimeType != null ? getAsciiBytes("Content-Type: " + mimeType + "\r\n") : new byte[0]);
byte[] contentType = mimeType != null ? getAsciiBytes("Content-Type: " + mimeType + "\r\n") : new byte[0];
return Flux.from(inputStream).
concatMap(region -> {
if (!region.getResource().isReadable()) {
return Flux.error(new EncodingException("Resource " +
region.getResource() + " is not readable"));
return Flux.error(new EncodingException(
"Resource " + region.getResource() + " is not readable"));
}
else {
return Flux.concat(
@ -121,11 +118,10 @@ public class ResourceRegionEncoder extends AbstractEncoder<ResourceRegion> {
private Flux<DataBuffer> getRegionPrefix(DataBufferFactory bufferFactory, byte[] startBoundary,
byte[] contentType, ResourceRegion region) {
return Flux.defer(() -> Flux.just(
bufferFactory.allocateBuffer(startBoundary.length).write(startBoundary),
bufferFactory.allocateBuffer(contentType.length).write(contentType),
bufferFactory.wrap(ByteBuffer.wrap(getContentRangeHeader(region))))
);
return Flux.just(
bufferFactory.wrap(startBoundary),
bufferFactory.wrap(contentType),
bufferFactory.wrap(getContentRangeHeader(region))); // only wrapping, no allocation
}
private Flux<DataBuffer> writeResourceRegion(
@ -146,8 +142,7 @@ public class ResourceRegionEncoder extends AbstractEncoder<ResourceRegion> {
private Flux<DataBuffer> getRegionSuffix(DataBufferFactory bufferFactory, String boundaryString) {
byte[] endBoundary = getAsciiBytes("\r\n--" + boundaryString + "--");
return Flux.defer(() -> Flux.just(
bufferFactory.allocateBuffer(endBoundary.length).write(endBoundary)));
return Flux.just(bufferFactory.wrap(endBoundary));
}
private byte[] getAsciiBytes(String in) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -82,40 +82,36 @@ public abstract class DataBufferUtils {
* Obtain a {@link ReadableByteChannel} from the given supplier, and read it into a
* {@code Flux} of {@code DataBuffer}s. Closes the channel when the flux is terminated.
* @param channelSupplier the supplier for the channel to read from
* @param dataBufferFactory the factory to create data buffers with
* @param bufferFactory the factory to create data buffers with
* @param bufferSize the maximum size of the data buffers
* @return a flux of data buffers read from the given channel
*/
public static Flux<DataBuffer> readByteChannel(
Callable<ReadableByteChannel> channelSupplier, DataBufferFactory dataBufferFactory, int bufferSize) {
Callable<ReadableByteChannel> channelSupplier, DataBufferFactory bufferFactory, int bufferSize) {
Assert.notNull(channelSupplier, "'channelSupplier' must not be null");
Assert.notNull(dataBufferFactory, "'dataBufferFactory' must not be null");
Assert.notNull(bufferFactory, "'dataBufferFactory' must not be null");
Assert.isTrue(bufferSize > 0, "'bufferSize' must be > 0");
return Flux.using(channelSupplier,
channel -> {
ReadableByteChannelGenerator generator =
new ReadableByteChannelGenerator(channel, dataBufferFactory,
bufferSize);
return Flux.generate(generator);
},
DataBufferUtils::closeChannel)
.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
channel -> Flux.generate(new ReadableByteChannelGenerator(channel, bufferFactory, bufferSize)),
DataBufferUtils::closeChannel);
// No doOnDiscard as operators used do not cache
}
/**
* Obtain a {@code AsynchronousFileChannel} from the given supplier, and read it into a
* {@code Flux} of {@code DataBuffer}s. Closes the channel when the flux is terminated.
* @param channelSupplier the supplier for the channel to read from
* @param dataBufferFactory the factory to create data buffers with
* @param bufferFactory the factory to create data buffers with
* @param bufferSize the maximum size of the data buffers
* @return a flux of data buffers read from the given channel
*/
public static Flux<DataBuffer> readAsynchronousFileChannel(
Callable<AsynchronousFileChannel> channelSupplier, DataBufferFactory dataBufferFactory, int bufferSize) {
Callable<AsynchronousFileChannel> channelSupplier, DataBufferFactory bufferFactory, int bufferSize) {
return readAsynchronousFileChannel(channelSupplier, 0, dataBufferFactory, bufferSize);
return readAsynchronousFileChannel(channelSupplier, 0, bufferFactory, bufferSize);
}
/**
@ -124,32 +120,30 @@ public abstract class DataBufferUtils {
* channel when the flux is terminated.
* @param channelSupplier the supplier for the channel to read from
* @param position the position to start reading from
* @param dataBufferFactory the factory to create data buffers with
* @param bufferFactory the factory to create data buffers with
* @param bufferSize the maximum size of the data buffers
* @return a flux of data buffers read from the given channel
*/
public static Flux<DataBuffer> readAsynchronousFileChannel(Callable<AsynchronousFileChannel> channelSupplier,
long position, DataBufferFactory dataBufferFactory, int bufferSize) {
long position, DataBufferFactory bufferFactory, int bufferSize) {
Assert.notNull(channelSupplier, "'channelSupplier' must not be null");
Assert.notNull(dataBufferFactory, "'dataBufferFactory' must not be null");
Assert.notNull(bufferFactory, "'dataBufferFactory' must not be null");
Assert.isTrue(position >= 0, "'position' must be >= 0");
Assert.isTrue(bufferSize > 0, "'bufferSize' must be > 0");
DataBuffer dataBuffer = dataBufferFactory.allocateBuffer(bufferSize);
ByteBuffer byteBuffer = dataBuffer.asByteBuffer(0, bufferSize);
Flux<DataBuffer> result = Flux.using(channelSupplier,
Flux<DataBuffer> flux = Flux.using(channelSupplier,
channel -> Flux.create(sink -> {
AsynchronousFileChannelReadCompletionHandler completionHandler =
new AsynchronousFileChannelReadCompletionHandler(channel,
sink, position, dataBufferFactory, bufferSize);
channel.read(byteBuffer, position, dataBuffer, completionHandler);
sink.onDispose(completionHandler::dispose);
ReadCompletionHandler handler =
new ReadCompletionHandler(channel, sink, position, bufferFactory, bufferSize);
DataBuffer dataBuffer = bufferFactory.allocateBuffer(bufferSize);
ByteBuffer byteBuffer = dataBuffer.asByteBuffer(0, bufferSize);
channel.read(byteBuffer, position, dataBuffer, handler);
sink.onDispose(handler::dispose);
}),
DataBufferUtils::closeChannel);
return result.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
return flux.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
}
/**
@ -246,8 +240,7 @@ public abstract class DataBufferUtils {
Flux<DataBuffer> flux = Flux.from(source);
return Flux.create(sink -> {
WritableByteChannelSubscriber subscriber =
new WritableByteChannelSubscriber(sink, channel);
WritableByteChannelSubscriber subscriber = new WritableByteChannelSubscriber(sink, channel);
sink.onDispose(subscriber);
flux.subscribe(subscriber);
});
@ -292,10 +285,9 @@ public abstract class DataBufferUtils {
Flux<DataBuffer> flux = Flux.from(source);
return Flux.create(sink -> {
AsynchronousFileChannelWriteCompletionHandler completionHandler =
new AsynchronousFileChannelWriteCompletionHandler(sink, channel, position);
sink.onDispose(completionHandler);
flux.subscribe(completionHandler);
WriteCompletionHandler handler = new WriteCompletionHandler(sink, channel, position);
sink.onDispose(handler);
flux.subscribe(handler);
});
}
@ -326,21 +318,21 @@ public abstract class DataBufferUtils {
Assert.notNull(publisher, "Publisher must not be null");
Assert.isTrue(maxByteCount >= 0, "'maxByteCount' must be a positive number");
return Flux.defer(() -> {
AtomicLong countDown = new AtomicLong(maxByteCount);
return Flux.from(publisher)
.map(buffer -> {
long remainder = countDown.addAndGet(-buffer.readableByteCount());
if (remainder < 0) {
int length = buffer.readableByteCount() + (int) remainder;
return buffer.slice(0, length);
}
else {
return buffer;
}
})
.takeUntil(buffer -> countDown.get() <= 0);
}); // no doOnDiscard necessary, as this method does not drop buffers
AtomicLong countDown = new AtomicLong(maxByteCount);
return Flux.from(publisher)
.map(buffer -> {
long remainder = countDown.addAndGet(-buffer.readableByteCount());
if (remainder < 0) {
int length = buffer.readableByteCount() + (int) remainder;
return buffer.slice(0, length);
}
else {
return buffer;
}
})
.takeUntil(buffer -> countDown.get() <= 0);
// No doOnDiscard as operators used do not cache (and drop) buffers
}
/**
@ -487,8 +479,7 @@ public abstract class DataBufferUtils {
}
private static class AsynchronousFileChannelReadCompletionHandler
implements CompletionHandler<Integer, DataBuffer> {
private static class ReadCompletionHandler implements CompletionHandler<Integer, DataBuffer> {
private final AsynchronousFileChannel channel;
@ -502,7 +493,7 @@ public abstract class DataBufferUtils {
private final AtomicBoolean disposed = new AtomicBoolean();
public AsynchronousFileChannelReadCompletionHandler(AsynchronousFileChannel channel,
public ReadCompletionHandler(AsynchronousFileChannel channel,
FluxSink<DataBuffer> sink, long position, DataBufferFactory dataBufferFactory, int bufferSize) {
this.channel = channel;
@ -586,7 +577,7 @@ public abstract class DataBufferUtils {
}
private static class AsynchronousFileChannelWriteCompletionHandler extends BaseSubscriber<DataBuffer>
private static class WriteCompletionHandler extends BaseSubscriber<DataBuffer>
implements CompletionHandler<Integer, ByteBuffer> {
private final FluxSink<DataBuffer> sink;
@ -601,7 +592,7 @@ public abstract class DataBufferUtils {
private final AtomicReference<DataBuffer> dataBuffer = new AtomicReference<>();
public AsynchronousFileChannelWriteCompletionHandler(
public WriteCompletionHandler(
FluxSink<DataBuffer> sink, AsynchronousFileChannel channel, long position) {
this.sink = sink;

View File

@ -20,6 +20,8 @@ import java.util.Collections;
import java.util.function.Consumer;
import org.junit.Test;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
@ -28,7 +30,6 @@ import org.springframework.core.ResolvableType;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.LeakAwareDataBufferFactory;
import org.springframework.core.io.buffer.support.DataBufferTestUtils;
@ -36,19 +37,18 @@ import org.springframework.core.io.support.ResourceRegion;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.nio.charset.StandardCharsets.*;
import static org.junit.Assert.*;
/**
* Test cases for {@link ResourceRegionEncoder} class.
*
* @author Brian Clozel
*/
public class ResourceRegionEncoderTests {
private ResourceRegionEncoder encoder = new ResourceRegionEncoder();
private DataBufferFactory bufferFactory = new LeakAwareDataBufferFactory();
private LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory();
@Test
@ -79,10 +79,13 @@ public class ResourceRegionEncoderTests {
.consumeNextWith(stringConsumer("Spring"))
.expectComplete()
.verify();
// TODO: https://github.com/reactor/reactor-core/issues/1634
// this.bufferFactory.checkForLeaks();
}
@Test
public void shouldEncodeMultipleResourceRegionsFileResource() throws Exception {
public void shouldEncodeMultipleResourceRegionsFileResource() {
Resource resource = new ClassPathResource("ResourceRegionEncoderTests.txt", getClass());
Flux<ResourceRegion> regions = Flux.just(
new ResourceRegion(resource, 0, 6),
@ -118,6 +121,33 @@ public class ResourceRegionEncoderTests {
.consumeNextWith(stringConsumer("\r\n--" + boundary + "--"))
.expectComplete()
.verify();
// TODO: https://github.com/reactor/reactor-core/issues/1634
// this.bufferFactory.checkForLeaks();
}
@Test // gh-
public void cancelWithoutDemandForMultipleResourceRegions() {
Resource resource = new ClassPathResource("ResourceRegionEncoderTests.txt", getClass());
Flux<ResourceRegion> regions = Flux.just(
new ResourceRegion(resource, 0, 6),
new ResourceRegion(resource, 7, 9),
new ResourceRegion(resource, 17, 4),
new ResourceRegion(resource, 22, 17)
);
String boundary = MimeTypeUtils.generateMultipartBoundaryString();
Flux<DataBuffer> flux = this.encoder.encode(regions, this.bufferFactory,
ResolvableType.forClass(ResourceRegion.class),
MimeType.valueOf("text/plain"),
Collections.singletonMap(ResourceRegionEncoder.BOUNDARY_STRING_HINT, boundary)
);
ZeroDemandSubscriber subscriber = new ZeroDemandSubscriber();
flux.subscribe(subscriber);
subscriber.cancel();
this.bufferFactory.checkForLeaks();
}
@Test
@ -142,6 +172,9 @@ public class ResourceRegionEncoderTests {
.consumeNextWith(stringConsumer("Spring"))
.expectError(EncodingException.class)
.verify();
// TODO: https://github.com/reactor/reactor-core/issues/1634
// this.bufferFactory.checkForLeaks();
}
protected Consumer<DataBuffer> stringConsumer(String expected) {
@ -154,4 +187,12 @@ public class ResourceRegionEncoderTests {
}
private static class ZeroDemandSubscriber extends BaseSubscriber<DataBuffer> {
@Override
protected void hookOnSubscribe(Subscription subscription) {
// Just subscribe without requesting
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -121,9 +121,23 @@ public abstract class AbstractDataBufferAllocatingTestCase {
if (this.bufferFactory instanceof NettyDataBufferFactory) {
ByteBufAllocator allocator = ((NettyDataBufferFactory) this.bufferFactory).getByteBufAllocator();
if (allocator instanceof PooledByteBufAllocator) {
PooledByteBufAllocatorMetric metric = ((PooledByteBufAllocator) allocator).metric();
long total = getAllocations(metric.directArenas()) + getAllocations(metric.heapArenas());
assertEquals("ByteBuf Leak: " + total + " unreleased allocations", 0, total);
Instant start = Instant.now();
while (true) {
PooledByteBufAllocatorMetric metric = ((PooledByteBufAllocator) allocator).metric();
long total = getAllocations(metric.directArenas()) + getAllocations(metric.heapArenas());
if (total == 0) {
return;
}
if (Instant.now().isBefore(start.plus(Duration.ofSeconds(5)))) {
try {
Thread.sleep(50);
}
catch (InterruptedException ex) {
// ignore
}
}
assertEquals("ByteBuf Leak: " + total + " unreleased allocations", 0, total);
}
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,8 +35,11 @@ import java.util.concurrent.CountDownLatch;
import io.netty.buffer.ByteBuf;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.mockito.stubbing.Answer;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
@ -184,6 +187,20 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase {
.verify();
}
// TODO: Remove ignore after https://github.com/reactor/reactor-core/issues/1634
@Ignore
@Test // gh-22107
public void readAsynchronousFileChannelCancelWithoutDemand() throws Exception {
URI uri = this.resource.getURI();
Flux<DataBuffer> flux = DataBufferUtils.readAsynchronousFileChannel(
() -> AsynchronousFileChannel.open(Paths.get(uri), StandardOpenOption.READ),
this.bufferFactory, 3);
BaseSubscriber<DataBuffer> subscriber = new ZeroDemandSubscriber();
flux.subscribe(subscriber);
subscriber.cancel();
}
@Test
public void readResource() throws Exception {
Flux<DataBuffer> flux = DataBufferUtils.read(this.resource, this.bufferFactory, 3);
@ -735,5 +752,12 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase {
}
private static class ZeroDemandSubscriber extends BaseSubscriber<DataBuffer> {
@Override
protected void hookOnSubscribe(Subscription subscription) {
// Just subscribe without requesting
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,9 +17,14 @@
package org.springframework.core.io.buffer;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jetbrains.annotations.NotNull;
import org.junit.After;
@ -37,6 +42,9 @@ import org.springframework.util.Assert;
*/
public class LeakAwareDataBufferFactory implements DataBufferFactory {
private static final Log logger = LogFactory.getLog(LeakAwareDataBufferFactory.class);
private final DataBufferFactory delegate;
private final List<LeakAwareDataBuffer> created = new ArrayList<>();
@ -65,13 +73,27 @@ public class LeakAwareDataBufferFactory implements DataBufferFactory {
* method.
*/
public void checkForLeaks() {
this.created.stream()
.filter(LeakAwareDataBuffer::isAllocated)
.findFirst()
.map(LeakAwareDataBuffer::leakError)
.ifPresent(leakError -> {
throw leakError;
});
Instant start = Instant.now();
while (true) {
if (this.created.stream().noneMatch(LeakAwareDataBuffer::isAllocated)) {
return;
}
if (Instant.now().isBefore(start.plus(Duration.ofSeconds(5)))) {
try {
Thread.sleep(50);
}
catch (InterruptedException ex) {
// ignore
}
}
List<AssertionError> errors = this.created.stream()
.filter(LeakAwareDataBuffer::isAllocated)
.map(LeakAwareDataBuffer::leakError)
.collect(Collectors.toList());
errors.forEach(it -> logger.error("Leaked error: ", it));
throw new AssertionError(errors.size() + " buffer leaks detected (see logs above)");
}
}
@Override

View File

@ -29,6 +29,7 @@ import org.springframework.core.codec.AbstractEncoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.codec.Hints;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.PooledDataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpLogging;
import org.springframework.http.MediaType;
@ -124,7 +125,8 @@ public class EncoderHttpMessageWriter<T> implements HttpMessageWriter<T> {
}))
.flatMap(buffer -> {
headers.setContentLength(buffer.readableByteCount());
return message.writeWith(Mono.just(buffer));
return message.writeWith(Mono.fromCallable(() -> buffer)
.doOnDiscard(PooledDataBuffer.class, PooledDataBuffer::release));
});
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -134,7 +134,7 @@ public class FormHttpMessageWriter extends LoggingCodecSupport
logFormData(form, hints);
String value = serializeForm(form, charset);
ByteBuffer byteBuffer = charset.encode(value);
DataBuffer buffer = message.bufferFactory().wrap(byteBuffer);
DataBuffer buffer = message.bufferFactory().wrap(byteBuffer); // wrapping only, no allocation
message.getHeaders().setContentLength(byteBuffer.remaining());
return message.writeWith(Mono.just(buffer));
});

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -143,30 +143,32 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter<Objec
sb.append("data:");
}
return Flux.concat(encodeText(sb, mediaType, factory),
Flux<DataBuffer> flux = Flux.concat(
encodeText(sb, mediaType, factory),
encodeData(data, valueType, mediaType, factory, hints),
encodeText("\n", mediaType, factory))
.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
encodeText("\n", mediaType, factory));
return flux.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
});
}
private void writeField(String fieldName, Object fieldValue, StringBuilder stringBuilder) {
stringBuilder.append(fieldName);
stringBuilder.append(':');
stringBuilder.append(fieldValue.toString());
stringBuilder.append("\n");
private void writeField(String fieldName, Object fieldValue, StringBuilder sb) {
sb.append(fieldName);
sb.append(':');
sb.append(fieldValue.toString());
sb.append("\n");
}
@SuppressWarnings("unchecked")
private <T> Flux<DataBuffer> encodeData(@Nullable T data, ResolvableType valueType,
private <T> Flux<DataBuffer> encodeData(@Nullable T dataValue, ResolvableType valueType,
MediaType mediaType, DataBufferFactory factory, Map<String, Object> hints) {
if (data == null) {
if (dataValue == null) {
return Flux.empty();
}
if (data instanceof String) {
String text = (String) data;
if (dataValue instanceof String) {
String text = (String) dataValue;
return Flux.from(encodeText(StringUtils.replace(text, "\n", "\ndata:") + "\n", mediaType, factory));
}
@ -175,15 +177,14 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter<Objec
}
return ((Encoder<T>) this.encoder)
.encode(Mono.just(data), factory, valueType, mediaType, hints)
.encode(Mono.just(dataValue), factory, valueType, mediaType, hints)
.concatWith(encodeText("\n", mediaType, factory));
}
private Mono<DataBuffer> encodeText(CharSequence text, MediaType mediaType, DataBufferFactory bufferFactory) {
Assert.notNull(mediaType.getCharset(), "Expected MediaType with charset");
byte[] bytes = text.toString().getBytes(mediaType.getCharset());
return Mono.defer(() ->
Mono.just(bufferFactory.allocateBuffer(bytes.length).write(bytes)));
return Mono.fromCallable(() -> bufferFactory.wrap(bytes)); // wrapping, not allocating
}
@Override

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -126,12 +126,10 @@ public abstract class AbstractJackson2Encoder extends Jackson2CodecSupport imple
.filter(mediaType -> mediaType.isCompatibleWith(mimeType))
.findFirst()
.map(mediaType -> {
byte[] separator =
STREAM_SEPARATORS.getOrDefault(mediaType, NEWLINE_SEPARATOR);
byte[] separator = STREAM_SEPARATORS.getOrDefault(mediaType, NEWLINE_SEPARATOR);
return Flux.from(inputStream).map(value -> {
DataBuffer buffer =
encodeValue(value, mimeType, bufferFactory, elementType, hints,
encoding);
DataBuffer buffer = encodeValue(
value, mimeType, bufferFactory, elementType, hints, encoding);
if (separator != null) {
buffer.write(separator);
}
@ -139,11 +137,9 @@ public abstract class AbstractJackson2Encoder extends Jackson2CodecSupport imple
});
})
.orElseGet(() -> {
ResolvableType listType =
ResolvableType.forClassWithGenerics(List.class, elementType);
ResolvableType listType = ResolvableType.forClassWithGenerics(List.class, elementType);
return Flux.from(inputStream).collectList().map(list ->
encodeValue(list, mimeType, bufferFactory, listType, hints,
encoding)).flux();
encodeValue(list, mimeType, bufferFactory, listType, hints, encoding)).flux();
});
}
}
@ -174,8 +170,7 @@ public abstract class AbstractJackson2Encoder extends Jackson2CodecSupport imple
OutputStream outputStream = buffer.asOutputStream();
try {
JsonGenerator generator =
getObjectMapper().getFactory().createGenerator(outputStream, encoding);
JsonGenerator generator = getObjectMapper().getFactory().createGenerator(outputStream, encoding);
writer.writeValue(generator, value);
release = false;
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -40,7 +40,7 @@ import org.springframework.core.codec.Hints;
import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.core.io.buffer.PooledDataBuffer;
import org.springframework.core.log.LogFormatUtils;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
@ -99,8 +99,6 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport
private final List<MediaType> supportedMediaTypes;
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
/**
* Constructor with a default list of part writers (String and Resource).
@ -187,17 +185,17 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport
ResolvableType elementType, @Nullable MediaType mediaType, ReactiveHttpOutputMessage outputMessage,
Map<String, Object> hints) {
return Mono.from(inputStream).flatMap(map -> {
if (this.formWriter == null || isMultipart(map, mediaType)) {
return writeMultipart(map, outputMessage, hints);
}
else {
@SuppressWarnings("unchecked")
MultiValueMap<String, String> formData = (MultiValueMap<String, String>) map;
return this.formWriter.write(Mono.just(formData), elementType, mediaType, outputMessage, hints);
}
});
return Mono.from(inputStream)
.flatMap(map -> {
if (this.formWriter == null || isMultipart(map, mediaType)) {
return writeMultipart(map, outputMessage, hints);
}
else {
@SuppressWarnings("unchecked")
Mono<MultiValueMap<String, String>> input = Mono.just((MultiValueMap<String, String>) map);
return this.formWriter.write(input, elementType, mediaType, outputMessage, hints);
}
});
}
private boolean isMultipart(MultiValueMap<String, ?> map, @Nullable MediaType contentType) {
@ -230,9 +228,12 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport
LogFormatUtils.formatValue(map, !traceOn) :
"parts " + map.keySet() + " (content masked)"));
DataBufferFactory bufferFactory = outputMessage.bufferFactory();
Flux<DataBuffer> body = Flux.fromIterable(map.entrySet())
.concatMap(entry -> encodePartValues(boundary, entry.getKey(), entry.getValue()))
.concatWith(Mono.just(generateLastLine(boundary)));
.concatMap(entry -> encodePartValues(boundary, entry.getKey(), entry.getValue(), bufferFactory))
.concatWith(generateLastLine(boundary, bufferFactory))
.doOnDiscard(PooledDataBuffer.class, PooledDataBuffer::release);
return outputMessage.writeWith(body);
}
@ -245,14 +246,16 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport
return MimeTypeUtils.generateMultipartBoundary();
}
private Flux<DataBuffer> encodePartValues(byte[] boundary, String name, List<?> values) {
private Flux<DataBuffer> encodePartValues(
byte[] boundary, String name, List<?> values, DataBufferFactory bufferFactory) {
return Flux.concat(values.stream().map(v ->
encodePart(boundary, name, v)).collect(Collectors.toList()));
encodePart(boundary, name, v, bufferFactory)).collect(Collectors.toList()));
}
@SuppressWarnings("unchecked")
private <T> Flux<DataBuffer> encodePart(byte[] boundary, String name, T value) {
MultipartHttpOutputMessage outputMessage = new MultipartHttpOutputMessage(this.bufferFactory, getCharset());
private <T> Flux<DataBuffer> encodePart(byte[] boundary, String name, T value, DataBufferFactory bufferFactory) {
MultipartHttpOutputMessage outputMessage = new MultipartHttpOutputMessage(bufferFactory, getCharset());
HttpHeaders outputHeaders = outputMessage.getHeaders();
T body;
@ -314,37 +317,46 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport
Flux<DataBuffer> partContent = partContentReady.thenMany(Flux.defer(outputMessage::getBody));
return Flux.concat(Mono.just(generateBoundaryLine(boundary)), partContent, Mono.just(generateNewLine()));
return Flux.concat(
generateBoundaryLine(boundary, bufferFactory),
partContent,
generateNewLine(bufferFactory));
}
private DataBuffer generateBoundaryLine(byte[] boundary) {
DataBuffer buffer = this.bufferFactory.allocateBuffer(boundary.length + 4);
buffer.write((byte)'-');
buffer.write((byte)'-');
buffer.write(boundary);
buffer.write((byte)'\r');
buffer.write((byte)'\n');
return buffer;
private Mono<DataBuffer> generateBoundaryLine(byte[] boundary, DataBufferFactory bufferFactory) {
return Mono.fromCallable(() -> {
DataBuffer buffer = bufferFactory.allocateBuffer(boundary.length + 4);
buffer.write((byte)'-');
buffer.write((byte)'-');
buffer.write(boundary);
buffer.write((byte)'\r');
buffer.write((byte)'\n');
return buffer;
});
}
private DataBuffer generateNewLine() {
DataBuffer buffer = this.bufferFactory.allocateBuffer(2);
buffer.write((byte)'\r');
buffer.write((byte)'\n');
return buffer;
private Mono<DataBuffer> generateNewLine(DataBufferFactory bufferFactory) {
return Mono.fromCallable(() -> {
DataBuffer buffer = bufferFactory.allocateBuffer(2);
buffer.write((byte)'\r');
buffer.write((byte)'\n');
return buffer;
});
}
private DataBuffer generateLastLine(byte[] boundary) {
DataBuffer buffer = this.bufferFactory.allocateBuffer(boundary.length + 6);
buffer.write((byte)'-');
buffer.write((byte)'-');
buffer.write(boundary);
buffer.write((byte)'-');
buffer.write((byte)'-');
buffer.write((byte)'\r');
buffer.write((byte)'\n');
return buffer;
private Mono<DataBuffer> generateLastLine(byte[] boundary, DataBufferFactory bufferFactory) {
return Mono.fromCallable(() -> {
DataBuffer buffer = bufferFactory.allocateBuffer(boundary.length + 6);
buffer.write((byte)'-');
buffer.write((byte)'-');
buffer.write(boundary);
buffer.write((byte)'-');
buffer.write((byte)'-');
buffer.write((byte)'\r');
buffer.write((byte)'\n');
return buffer;
});
}
@ -391,29 +403,31 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport
if (this.body != null) {
return Mono.error(new IllegalStateException("Multiple calls to writeWith() not supported"));
}
this.body = Flux.just(generateHeaders()).concatWith(body);
this.body = generateHeaders().concatWith(body);
// We don't actually want to write (just save the body Flux)
return Mono.empty();
}
private DataBuffer generateHeaders() {
DataBuffer buffer = this.bufferFactory.allocateBuffer();
for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
byte[] headerName = entry.getKey().getBytes(this.charset);
for (String headerValueString : entry.getValue()) {
byte[] headerValue = headerValueString.getBytes(this.charset);
buffer.write(headerName);
buffer.write((byte)':');
buffer.write((byte)' ');
buffer.write(headerValue);
buffer.write((byte)'\r');
buffer.write((byte)'\n');
private Mono<DataBuffer> generateHeaders() {
return Mono.fromCallable(() -> {
DataBuffer buffer = this.bufferFactory.allocateBuffer();
for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
byte[] headerName = entry.getKey().getBytes(this.charset);
for (String headerValueString : entry.getValue()) {
byte[] headerValue = headerValueString.getBytes(this.charset);
buffer.write(headerName);
buffer.write((byte)':');
buffer.write((byte)' ');
buffer.write(headerValue);
buffer.write((byte)'\r');
buffer.write((byte)'\n');
}
}
}
buffer.write((byte)'\r');
buffer.write((byte)'\n');
return buffer;
buffer.write((byte)'\r');
buffer.write((byte)'\n');
return buffer;
});
}
@Override

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,7 @@ import javax.xml.bind.annotation.XmlRootElement;
import javax.xml.bind.annotation.XmlType;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.AbstractSingleValueEncoder;
@ -97,7 +98,7 @@ public class Jaxb2XmlEncoder extends AbstractSingleValueEncoder<Object> {
}
@Override
protected Flux<DataBuffer> encode(Object value, DataBufferFactory dataBufferFactory,
protected Flux<DataBuffer> encode(Object value, DataBufferFactory bufferFactory,
ResolvableType type, @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
if (!Hints.isLoggingSuppressed(hints)) {
@ -107,29 +108,30 @@ public class Jaxb2XmlEncoder extends AbstractSingleValueEncoder<Object> {
});
}
boolean release = true;
DataBuffer buffer = dataBufferFactory.allocateBuffer(1024);
OutputStream outputStream = buffer.asOutputStream();
Class<?> clazz = ClassUtils.getUserClass(value);
try {
Marshaller marshaller = initMarshaller(clazz);
marshaller.marshal(value, outputStream);
release = false;
return Flux.just(buffer);
}
catch (MarshalException ex) {
return Flux.error(new EncodingException(
"Could not marshal " + value.getClass() + " to XML", ex));
}
catch (JAXBException ex) {
return Flux.error(new CodecException("Invalid JAXB configuration", ex));
}
finally {
if (release) {
DataBufferUtils.release(buffer);
return Flux.defer(() -> {
boolean release = true;
DataBuffer buffer = bufferFactory.allocateBuffer(1024);
OutputStream outputStream = buffer.asOutputStream();
Class<?> clazz = ClassUtils.getUserClass(value);
try {
Marshaller marshaller = initMarshaller(clazz);
marshaller.marshal(value, outputStream);
release = false;
return Mono.fromCallable(() -> buffer); // Rely on doOnDiscard in base class
}
}
catch (MarshalException ex) {
return Flux.error(new EncodingException(
"Could not marshal " + value.getClass() + " to XML", ex));
}
catch (JAXBException ex) {
return Flux.error(new CodecException("Invalid JAXB configuration", ex));
}
finally {
if (release) {
DataBufferUtils.release(buffer);
}
}
});
}
private Marshaller initMarshaller(Class<?> clazz) throws JAXBException {

View File

@ -0,0 +1,186 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;
import org.junit.Test;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.LeakAwareDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpOutputMessage;
import org.springframework.http.client.MultipartBodyBuilder;
import org.springframework.http.codec.json.Jackson2JsonEncoder;
import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
import org.springframework.http.codec.xml.Jaxb2XmlEncoder;
/**
* Test scenarios for data buffer leaks.
* @author Rossen Stoyanchev
* @since 5.2
*/
public class CodecDataBufferLeakTests {
private final LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory();
@Test // gh-22107
public void cancelWithEncoderHttpMessageWriterAndSingleValue() {
CharSequenceEncoder encoder = CharSequenceEncoder.allMimeTypes();
HttpMessageWriter<CharSequence> writer = new EncoderHttpMessageWriter<>(encoder);
CancellingOutputMessage outputMessage = new CancellingOutputMessage(this.bufferFactory);
writer.write(Mono.just("foo"), ResolvableType.forType(String.class), MediaType.TEXT_PLAIN,
outputMessage, Collections.emptyMap()).block(Duration.ofSeconds(5));
this.bufferFactory.checkForLeaks();
}
@Test // gh-22107
public void cancelWithJackson() {
Jackson2JsonEncoder encoder = new Jackson2JsonEncoder();
Flux<DataBuffer> flux = encoder.encode(Flux.just(new Pojo("foofoo", "barbar"), new Pojo("bar", "baz")),
this.bufferFactory, ResolvableType.forClass(Pojo.class),
MediaType.APPLICATION_JSON, Collections.emptyMap());
BaseSubscriber<DataBuffer> subscriber = new ZeroDemandSubscriber();
flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just)..
subscriber.cancel();
this.bufferFactory.checkForLeaks();
}
@Test // gh-22107
public void cancelWithJaxb2() {
Jaxb2XmlEncoder encoder = new Jaxb2XmlEncoder();
Flux<DataBuffer> flux = encoder.encode(Mono.just(new Pojo("foo", "bar")),
this.bufferFactory, ResolvableType.forClass(Pojo.class),
MediaType.APPLICATION_XML, Collections.emptyMap());
BaseSubscriber<DataBuffer> subscriber = new ZeroDemandSubscriber();
flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just)..
subscriber.cancel();
this.bufferFactory.checkForLeaks();
}
@Test // gh-22107
public void cancelWithMultipartContent() {
MultipartBodyBuilder builder = new MultipartBodyBuilder();
builder.part("part1", "value1");
builder.part("part2", "value2");
List<HttpMessageWriter<?>> writers = ClientCodecConfigurer.create().getWriters();
MultipartHttpMessageWriter writer = new MultipartHttpMessageWriter(writers);
CancellingOutputMessage outputMessage = new CancellingOutputMessage(this.bufferFactory);
writer.write(Mono.just(builder.build()), null, MediaType.MULTIPART_FORM_DATA,
outputMessage, Collections.emptyMap()).block(Duration.ofSeconds(5));
this.bufferFactory.checkForLeaks();
}
@Test // gh-22107
public void cancelWithSse() {
ServerSentEvent<?> event = ServerSentEvent.builder().data("bar").id("c42").event("foo").build();
ServerSentEventHttpMessageWriter writer = new ServerSentEventHttpMessageWriter(new Jackson2JsonEncoder());
CancellingOutputMessage outputMessage = new CancellingOutputMessage(this.bufferFactory);
writer.write(Mono.just(event), ResolvableType.forClass(ServerSentEvent.class), MediaType.TEXT_EVENT_STREAM,
outputMessage, Collections.emptyMap()).block(Duration.ofSeconds(5));
this.bufferFactory.checkForLeaks();
}
private static class CancellingOutputMessage implements ReactiveHttpOutputMessage {
private final DataBufferFactory bufferFactory;
public CancellingOutputMessage(DataBufferFactory bufferFactory) {
this.bufferFactory = bufferFactory;
}
@Override
public DataBufferFactory bufferFactory() {
return this.bufferFactory;
}
@Override
public void beforeCommit(Supplier<? extends Mono<Void>> action) {
}
@Override
public boolean isCommitted() {
return false;
}
@Override
public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
Flux<? extends DataBuffer> flux = Flux.from(body);
BaseSubscriber<DataBuffer> subscriber = new ZeroDemandSubscriber();
flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just)..
subscriber.cancel();
return Mono.empty();
}
@Override
public Mono<Void> writeAndFlushWith(Publisher<? extends Publisher<? extends DataBuffer>> body) {
Flux<? extends DataBuffer> flux = Flux.from(body).concatMap(Flux::from);
BaseSubscriber<DataBuffer> subscriber = new ZeroDemandSubscriber();
flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just)..
subscriber.cancel();
return Mono.empty();
}
@Override
public Mono<Void> setComplete() {
throw new UnsupportedOperationException();
}
@Override
public HttpHeaders getHeaders() {
return new HttpHeaders();
}
}
private static class ZeroDemandSubscriber extends BaseSubscriber<DataBuffer> {
@Override
protected void hookOnSubscribe(Subscription subscription) {
// Just subscribe without requesting
}
}
}

View File

@ -197,8 +197,11 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTestCas
Mono<MultiValueMap<String, HttpEntity<?>>> result = Mono.just(bodyBuilder.build());
Map<String, Object> hints = Collections.emptyMap();
this.writer.write(result, null, MediaType.MULTIPART_FORM_DATA, this.response, hints).block();
this.writer.write(result, null, MediaType.MULTIPART_FORM_DATA, this.response, Collections.emptyMap())
.block(Duration.ofSeconds(5));
// Make sure body is consumed to avoid leak reports
this.response.getBodyAsString().block(Duration.ofSeconds(5));
}
@Test // SPR-16376

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -147,6 +147,7 @@ public class MockServerHttpResponse extends AbstractServerHttpResponse {
Assert.notNull(charset, "'charset' must not be null");
byte[] bytes = new byte[buffer.readableByteCount()];
buffer.read(bytes);
DataBufferUtils.release(buffer);
return new String(bytes, charset);
}