diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index d089e367e9e..0aa38699931 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -393,24 +393,16 @@ public abstract class DataBufferUtils { public static Flux takeUntilByteCount(Publisher publisher, long maxByteCount) { Assert.notNull(publisher, "Publisher must not be null"); Assert.isTrue(maxByteCount >= 0, "'maxByteCount' must be a positive number"); - AtomicLong byteCountDown = new AtomicLong(maxByteCount); + AtomicLong countDown = new AtomicLong(maxByteCount); - return Flux.from(publisher). - takeWhile(dataBuffer -> { - int delta = -dataBuffer.readableByteCount(); - long currentCount = byteCountDown.getAndAdd(delta); - return currentCount >= 0; - }). - map(dataBuffer -> { - long currentCount = byteCountDown.get(); - if (currentCount >= 0) { - return dataBuffer; - } - else { - // last buffer - int size = (int) (currentCount + dataBuffer.readableByteCount()); - return dataBuffer.slice(0, size); - } + return Flux.from(publisher) + .takeWhile(buffer -> { + int delta = -buffer.readableByteCount(); + return countDown.getAndAdd(delta) >= 0; + }) + .map(buffer -> { + long count = countDown.get(); + return count >= 0 ? buffer : buffer.slice(0, buffer.readableByteCount() + (int) count); }); } @@ -427,27 +419,23 @@ public abstract class DataBufferUtils { Assert.isTrue(maxByteCount >= 0, "'maxByteCount' must be a positive number"); AtomicLong byteCountDown = new AtomicLong(maxByteCount); - return Flux.from(publisher). - skipUntil(dataBuffer -> { - int delta = -dataBuffer.readableByteCount(); - long currentCount = byteCountDown.addAndGet(delta); - if (currentCount < 0) { - return true; - } - else { - DataBufferUtils.release(dataBuffer); + return Flux.from(publisher) + .skipUntil(buffer -> { + int delta = -buffer.readableByteCount(); + if (byteCountDown.addAndGet(delta) >= 0) { + DataBufferUtils.release(buffer); return false; } - }). - map(dataBuffer -> { - long currentCount = byteCountDown.get(); - // slice first buffer, then let others flow through - if (currentCount < 0) { - int skip = (int) (currentCount + dataBuffer.readableByteCount()); + return true; + }) + .map(buffer -> { + long count = byteCountDown.get(); + if (count < 0) { + int skipCount = buffer.readableByteCount() + (int) count; byteCountDown.set(0); - return dataBuffer.slice(skip, dataBuffer.readableByteCount() - skip); + return buffer.slice(skipCount, buffer.readableByteCount() - skipCount); } - return dataBuffer; + return buffer; }); } diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index 028e7740f29..8f19ea31085 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -225,7 +225,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { } @Test - public void takeUntilByteCount() throws Exception { + public void takeUntilByteCount() { DataBuffer foo = stringBuffer("foo"); DataBuffer bar = stringBuffer("bar"); DataBuffer baz = stringBuffer("baz"); @@ -242,7 +242,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { } @Test - public void skipUntilByteCount() throws Exception { + public void skipUntilByteCount() { DataBuffer foo = stringBuffer("foo"); DataBuffer bar = stringBuffer("bar"); DataBuffer baz = stringBuffer("baz"); @@ -257,7 +257,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { } @Test - public void skipUntilByteCountShouldSkipAll() throws Exception { + public void skipUntilByteCountShouldSkipAll() { DataBuffer foo = stringBuffer("foo"); DataBuffer bar = stringBuffer("bar"); DataBuffer baz = stringBuffer("baz");