diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index 47ba75468a..e048576f96 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -33,7 +33,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -import java.util.function.IntPredicate; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; @@ -334,14 +333,23 @@ public abstract class DataBufferUtils { public static Flux takeUntilByteCount(Publisher publisher, long maxByteCount) { Assert.notNull(publisher, "Publisher must not be null"); Assert.isTrue(maxByteCount >= 0, "'maxByteCount' must be a positive number"); - AtomicLong countDown = new AtomicLong(maxByteCount); - return Flux.from(publisher) - .map(buffer -> { - long count = countDown.addAndGet(-buffer.readableByteCount()); - return count >= 0 ? buffer : buffer.slice(0, buffer.readableByteCount() + (int) count); - }) - .takeUntil(buffer -> countDown.get() <= 0); + return Flux.defer(() -> { + AtomicLong countDown = new AtomicLong(maxByteCount); + + return Flux.from(publisher) + .map(buffer -> { + long remainder = countDown.addAndGet(-buffer.readableByteCount()); + if (remainder < 0) { + int length = buffer.readableByteCount() + (int) remainder; + return buffer.slice(0, length); + } + else { + return buffer; + } + }) + .takeUntil(buffer -> countDown.get() <= 0); + }); // no doOnDiscard necessary, as this method does not drop buffers } /** @@ -355,26 +363,28 @@ public abstract class DataBufferUtils { public static Flux skipUntilByteCount(Publisher publisher, long maxByteCount) { Assert.notNull(publisher, "Publisher must not be null"); Assert.isTrue(maxByteCount >= 0, "'maxByteCount' must be a positive number"); - AtomicLong byteCountDown = new AtomicLong(maxByteCount); - return Flux.from(publisher) - .skipUntil(buffer -> { - int delta = -buffer.readableByteCount(); - if (byteCountDown.addAndGet(delta) >= 0) { - DataBufferUtils.release(buffer); - return false; - } - return true; - }) - .map(buffer -> { - long count = byteCountDown.get(); - if (count < 0) { - int skipCount = buffer.readableByteCount() + (int) count; - byteCountDown.set(0); - return buffer.slice(skipCount, buffer.readableByteCount() - skipCount); - } - return buffer; - }); + return Flux.defer(() -> { + AtomicLong countDown = new AtomicLong(maxByteCount); + + return Flux.from(publisher) + .skipUntil(buffer -> { + long remainder = countDown.addAndGet(-buffer.readableByteCount()); + return remainder < 0; + }) + .map(buffer -> { + long remainder = countDown.get(); + if (remainder < 0) { + countDown.set(0); + int start = buffer.readableByteCount() + (int)remainder; + int length = (int) -remainder; + return buffer.slice(start, length); + } + else { + return buffer; + } + }); + }).doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); } /** @@ -432,24 +442,14 @@ public abstract class DataBufferUtils { Assert.notNull(dataBuffers, "'dataBuffers' must not be null"); return Flux.from(dataBuffers) - .onErrorResume(DataBufferUtils::exceptionDataBuffer) .collectList() .filter(list -> !list.isEmpty()) - .flatMap(list -> { - for (int i = 0; i < list.size(); i++) { - DataBuffer dataBuffer = list.get(i); - if (dataBuffer instanceof ExceptionDataBuffer) { - list.subList(0, i).forEach(DataBufferUtils::release); - return Mono.error(((ExceptionDataBuffer) dataBuffer).throwable()); - } - } + .map(list -> { DataBufferFactory bufferFactory = list.get(0).factory(); - return Mono.just(bufferFactory.join(list)); - }); - } + return bufferFactory.join(list); + }) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); - private static Mono exceptionDataBuffer(Throwable throwable) { - return Mono.just(new ExceptionDataBuffer(throwable)); } @@ -638,153 +638,4 @@ public abstract class DataBufferUtils { } } - /** - * DataBuffer implementation that holds a {@link Throwable}, used in {@link #join(Publisher)}. - */ - private static final class ExceptionDataBuffer implements DataBuffer { - - private final Throwable throwable; - - - public ExceptionDataBuffer(Throwable throwable) { - this.throwable = throwable; - } - - public Throwable throwable() { - return this.throwable; - } - - // Unsupported - - @Override - public DataBufferFactory factory() { - throw new UnsupportedOperationException(); - } - - @Override - public int indexOf(IntPredicate predicate, int fromIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public int lastIndexOf(IntPredicate predicate, int fromIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public int readableByteCount() { - throw new UnsupportedOperationException(); - } - - @Override - public int writableByteCount() { - throw new UnsupportedOperationException(); - } - - @Override - public int capacity() { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer capacity(int capacity) { - throw new UnsupportedOperationException(); - } - - @Override - public int readPosition() { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer readPosition(int readPosition) { - throw new UnsupportedOperationException(); - } - - @Override - public int writePosition() { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer writePosition(int writePosition) { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(int index) { - throw new UnsupportedOperationException(); - } - - @Override - public byte read() { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer read(byte[] destination) { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer read(byte[] destination, int offset, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer write(byte b) { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer write(byte[] source) { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer write(byte[] source, int offset, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer write(DataBuffer... buffers) { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer write(ByteBuffer... buffers) { - throw new UnsupportedOperationException(); - } - - @Override - public DataBuffer slice(int index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuffer asByteBuffer() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuffer asByteBuffer(int index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream asInputStream() { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream asInputStream(boolean releaseOnClose) { - throw new UnsupportedOperationException(); - } - - @Override - public OutputStream asOutputStream() { - throw new UnsupportedOperationException(); - } - } - } diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index cb47dac28a..8f53d733fe 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -412,6 +412,20 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .verify(Duration.ofSeconds(5)); } + @Test + public void takeUntilByteCountErrorInFlux() { + DataBuffer foo = stringBuffer("foo"); + Flux flux = + Flux.just(foo).concatWith(Mono.error(new RuntimeException())); + + Flux result = DataBufferUtils.takeUntilByteCount(flux, 5L); + + StepVerifier.create(result) + .consumeNextWith(stringConsumer("foo")) + .expectError(RuntimeException.class) + .verify(Duration.ofSeconds(5)); + } + @Test public void takeUntilByteCountExact() { @@ -444,6 +458,18 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .verify(Duration.ofSeconds(5)); } + @Test + public void skipUntilByteCountErrorInFlux() { + DataBuffer foo = stringBuffer("foo"); + Flux flux = + Flux.just(foo).concatWith(Mono.error(new RuntimeException())); + Flux result = DataBufferUtils.skipUntilByteCount(flux, 3L); + + StepVerifier.create(result) + .expectError(RuntimeException.class) + .verify(Duration.ofSeconds(5)); + } + @Test public void skipUntilByteCountShouldSkipAll() { DataBuffer foo = stringBuffer("foo");